core.py 59.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
import queue
5
import signal
6
7
import threading
import time
8
from collections import deque
9
from collections.abc import Callable, Generator
10
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
11
from contextlib import ExitStack, contextmanager
12
from inspect import isclass, signature
13
from logging import DEBUG
14
from typing import Any, TypeVar, cast
15

16
import msgspec
17
18
import zmq

19
20
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
21
from vllm.envs import enable_envs_cache
22
from vllm.logger import init_logger
23
from vllm.logging_utils.dump_input import dump_engine_exception
24
from vllm.lora.request import LoRARequest
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.multimodal.cache import engine_receiver_cache_from_config
27
from vllm.tasks import POOLING_TASKS, SupportedTask
28
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
29
30
31
32
from vllm.utils.gc_utils import (
    freeze_gc_heap,
    maybe_attach_gc_debug_callback,
)
33
from vllm.utils.hashing import get_hash_fn_by_name
34
from vllm.utils.network_utils import make_zmq_socket
35
from vllm.utils.system_utils import decorate_logs, set_process_title
36
37
38
39
40
41
42
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    generate_scheduler_kv_cache_config,
    get_kv_cache_configs,
    get_request_block_hasher,
    init_none_hash,
)
43
from vllm.v1.core.sched.interface import SchedulerInterface
44
from vllm.v1.core.sched.output import SchedulerOutput
45
from vllm.v1.engine import (
46
    EngineCoreOutput,
47
48
49
    EngineCoreOutputs,
    EngineCoreRequest,
    EngineCoreRequestType,
50
    FinishReason,
51
52
53
54
55
56
57
58
59
60
    ReconfigureDistributedRequest,
    ReconfigureRankType,
    UtilityOutput,
    UtilityResult,
)
from vllm.v1.engine.utils import (
    EngineHandshakeMetadata,
    EngineZmqAddresses,
    get_device_indices,
)
61
from vllm.v1.executor import Executor
62
from vllm.v1.kv_cache_interface import KVCacheConfig
63
from vllm.v1.metrics.stats import SchedulerStats
64
from vllm.v1.outputs import ModelRunnerOutput
65
from vllm.v1.request import Request, RequestStatus
66
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
67
from vllm.v1.structured_output import StructuredOutputManager
68
69
70
71
from vllm.version import __version__ as VLLM_VERSION

# Module-level logger, created via vLLM's init_logger with this module's name.
logger = init_logger(__name__)

# Polling timeout; the ``_S`` suffix denotes seconds.
POLLING_TIMEOUT_S = 2.5
# Handshake timeout; the ``_MINS`` suffix denotes minutes.
HANDSHAKE_TIMEOUT_MINS = 5

# Generic return type used by ``EngineCore.collective_rpc``.
_R = TypeVar("_R")


class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the scheduler and the structured-output manager.
    ``step`` (or ``step_with_batch_queue`` when the executor supports more than
    one concurrent batch) drives one schedule -> execute -> update iteration;
    ``step_fn`` is bound to whichever of the two applies.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        executor_fail_callback: Callable | None = None,
    ):
        """Build the executor, KV caches and scheduler.

        Args:
            vllm_config: Engine configuration. Mutated in place:
                ``cache_config.num_gpu_blocks`` / ``num_cpu_blocks`` are set
                after profiling, and ``scheduler_config.enable_chunked_prefill``
                is turned off when the model has no KV cache groups.
            executor_class: Executor type, instantiated as
                ``executor_class(vllm_config)``.
            log_stats: Stored on the engine and forwarded to the scheduler.
            executor_fail_callback: If given, registered on the executor via
                ``register_failure_callback``.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        self.vllm_config = vllm_config
        if vllm_config.parallel_config.data_parallel_rank == 0:
            logger.info(
                "Initializing a V1 LLM engine (v%s) with config: %s",
                VLLM_VERSION,
                vllm_config,
            )

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(executor_fail_callback)

        # Overwritten by _initialize_kv_caches when the model has a KV cache.
        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
            vllm_config
        )

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler.
        Scheduler = vllm_config.scheduler_config.get_scheduler_cls()

        if len(kv_cache_config.kv_cache_groups) == 0:  # noqa: SIM102
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            if vllm_config.scheduler_config.enable_chunked_prefill:
                logger.warning("Disabling chunked prefill for model without KVCache")
                vllm_config.scheduler_config.enable_chunked_prefill = False

        # The scheduler (and the request block hasher below) work in units of
        # the cache block size scaled by both context-parallel sizes.
        scheduler_block_size = (
            vllm_config.cache_config.block_size
            * vllm_config.parallel_config.decode_context_parallel_size
            * vllm_config.parallel_config.prefill_context_parallel_size
        )

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )
        self.use_spec_decode = vllm_config.speculative_config is not None
        if self.scheduler.connector is not None:  # type: ignore
            self.model_executor.init_kv_output_aggregator(self.scheduler.connector)  # type: ignore

        self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
        self.mm_receiver_cache = engine_receiver_cache_from_config(
            vllm_config, mm_registry
        )

        # If a KV connector is initialized for scheduler, we want to collect
        # handshake metadata from all workers so the connector in the scheduler
        # will have the full context
        kv_connector = self.scheduler.get_kv_connector()
        if kv_connector is not None:
            # Collect and store KV connector xfer metadata from workers
            # (after KV cache registration)
            xfer_handshake_metadata = (
                self.model_executor.get_kv_connector_handshake_metadata()
            )

            if xfer_handshake_metadata:
                # xfer_handshake_metadata is list of dicts from workers
                # Each dict already has structure {tp_rank: metadata}
                # Merge all worker dicts into a single dict
                content: dict[int, Any] = {}
                for worker_dict in xfer_handshake_metadata:
                    if worker_dict is not None:
                        content.update(worker_dict)
                kv_connector.set_xfer_handshake_metadata(content)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: (
            deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]] | None
        ) = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
            self.batch_queue = deque(maxlen=self.batch_queue_size)

        self.is_ec_producer = (
            vllm_config.ec_transfer_config is not None
            and vllm_config.ec_transfer_config.is_ec_producer
        )
        self.is_pooling_model = vllm_config.model_config.runner_type == "pooling"

        # Block hashes are only needed for prefix caching or a KV connector.
        self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
        if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo
            )
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                scheduler_block_size, caching_hash_fn
            )

        self.step_fn = (
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )
        self.async_scheduling = vllm_config.scheduler_config.async_scheduling

        # Request-id lists queued for abort; drained by _process_aborts_queue.
        self.aborts_queue = queue.Queue[list[str]]()

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()
        # If enable, attach GC debugger after static variable freeze.
        maybe_attach_gc_debug_callback()
        # Enable environment variable cache (e.g. assume no more
        # environment variable overrides after this point)
        enable_envs_cache()

    def _initialize_kv_caches(
        self, vllm_config: VllmConfig
    ) -> tuple[int, int, KVCacheConfig]:
        """Size and allocate the KV caches on the workers.

        Determines the memory available for KV cache, builds the per-worker
        KV cache configs, initializes the executor from them and warms up.
        Sets ``self.available_gpu_memory_for_kv_cache`` when the model has a
        KV cache, and pushes a reduced ``max_model_len`` to the workers if
        ``get_kv_cache_configs`` changed it.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)``;
            ``num_cpu_blocks`` is always 0 here.
        """
        # Monotonic clock: the value is only used to measure a duration.
        start = time.monotonic()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # Elastic-EP scale-up: take the KV cache memory size from the
                # DP group rather than profiling. NOTE(review): `dp_group` is
                # not assigned by this class; presumably a subclass sets it
                # before this runs — confirm.
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = (
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                )
                available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                    kv_cache_specs
                )
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = self.model_executor.determine_available_memory()
                self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)

        # Track max_model_len before KV cache config to detect auto-fit changes
        max_model_len_before = vllm_config.model_config.max_model_len

        kv_cache_configs = get_kv_cache_configs(
            vllm_config, kv_cache_specs, available_gpu_memory
        )

        # If auto-fit reduced max_model_len, sync the new value to workers.
        # This is needed because workers were spawned before memory profiling
        # and have the original (larger) max_model_len cached.
        max_model_len_after = vllm_config.model_config.max_model_len
        if max_model_len_after != max_model_len_before:
            self.collective_rpc("update_max_model_len", args=(max_model_len_after,))

        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
        num_gpu_blocks = scheduler_kv_cache_config.num_blocks
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.monotonic() - start
        logger.info_once(
            "init engine (profile, create kv cache, warmup model) took %.2f seconds",
            elapsed,
            scope="local",
        )
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks reported by the model executor."""
        return self.model_executor.supported_tasks

    def add_request(self, request: Request, request_wave: int = 0):
        """Add request to the scheduler.

        `request_wave`: indicate which wave of requests this is expected to
        belong to in DP case (not used by this base implementation).

        Raises:
            TypeError: if ``request.request_id`` is not a ``str``.
            ValueError: if the request has pooling params whose task is not
                among the supported pooling tasks.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}"
            )

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks() if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {pooling_params.task!r} "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

        # NOTE(review): only a warning is emitted; `kv_transfer_params` is left
        # on the request, presumably ignored downstream without a connector.
        if request.kv_transfer_params is not None and (
            not self.scheduler.get_kv_connector()
        ):
            logger.warning(
                "Got kv_transfer_params, but no KVConnector found. "
                "Disabling KVTransfer for this request."
            )

        self.scheduler.add_request(request)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)

    @contextmanager
    def log_error_detail(self, scheduler_output: SchedulerOutput):
        """Execute the model and log detailed info on failure.

        Any ``Exception`` raised inside the ``with`` body triggers
        ``dump_engine_exception`` and is then re-raised unchanged.
        """
        try:
            yield
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: This method is exception-free
            dump_engine_exception(
                self.vllm_config, scheduler_output, self.scheduler.make_stats()
            )
            raise

    def _log_err_callback(self, scheduler_output: SchedulerOutput):
        """Log error details of a future that's not expected to return a result.

        Returns a ``Future.add_done_callback``-compatible function that reads
        the future's result under ``log_error_detail`` and asserts it is None.
        """

        def callback(f, sched_output=scheduler_output):
            with self.log_error_detail(sched_output):
                result = f.result()
                assert result is None

        return callback

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed. Returns ``({}, False)`` when the scheduler holds no
        requests.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        future = self.model_executor.execute_model(scheduler_output, non_block=True)
        # Computed while the model runs; only consumed if sampling is separate.
        grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
        with self.log_error_detail(scheduler_output):
            model_output = future.result()
            if model_output is None:
                model_output = self.model_executor.sample_tokens(grammar_output)

        # Before processing the model output, process any aborts that happened
        # during the model execution.
        self._process_aborts_queue()
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

    def post_step(self, model_executed: bool) -> None:
        """Feed draft token ids from the executor back into the scheduler.

        Only acts when speculative decoding is enabled, async scheduling is
        disabled, and the preceding step executed the model.
        """
        # When using async scheduling we can't get draft token ids in advance,
        # so we update draft token ids in the worker process and don't
        # need to update draft token ids here.
        if not self.async_scheduling and self.use_spec_decode and model_executed:
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:
                self.scheduler.update_draft_token_ids(draft_token_ids)

    def step_with_batch_queue(
        self,
    ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, fulfilling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        batch_queue = self.batch_queue
        assert batch_queue is not None

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        assert len(batch_queue) < self.batch_queue_size

        model_executed = False
        # Set when sampling for the new batch must wait for the prior step.
        deferred_scheduler_output: SchedulerOutput | None = None
        if self.scheduler.has_requests():
            scheduler_output = self.scheduler.schedule()
            exec_future = self.model_executor.execute_model(
                scheduler_output, non_block=True
            )
            if not self.is_ec_producer:
                model_executed = scheduler_output.total_num_scheduled_tokens > 0

            if self.is_pooling_model or not model_executed:
                # No sampling required (no requests scheduled).
                future = cast(Future[ModelRunnerOutput], exec_future)
            else:
                # The execute future is not queued in this branch, so attach a
                # callback that surfaces its errors.
                exec_future.add_done_callback(self._log_err_callback(scheduler_output))

                if not scheduler_output.pending_structured_output_tokens:
                    # We aren't waiting for any tokens, get any grammar output
                    # and sample immediately.
                    grammar_output = self.scheduler.get_grammar_bitmask(
                        scheduler_output
                    )
                    future = self.model_executor.sample_tokens(
                        grammar_output, non_block=True
                    )
                else:
                    # We need to defer sampling until we have processed the model output
                    # from the prior step.
                    deferred_scheduler_output = scheduler_output

            if deferred_scheduler_output is None:
                # Add this step's future to the queue.
                batch_queue.appendleft((future, scheduler_output))
                if (
                    model_executed
                    and len(batch_queue) < self.batch_queue_size
                    and not batch_queue[-1][0].done()
                ):
                    # Don't block on next worker response unless the queue is full
                    # or there are no more requests to schedule.
                    return None, True

        elif not batch_queue:
            # Queue is empty. We should not reach here since this method should
            # only be called when the scheduler contains requests or the queue
            # is non-empty.
            return None, False

        # Block until the next result is available.
        future, scheduler_output = batch_queue.pop()
        with self.log_error_detail(scheduler_output):
            model_output = future.result()

        # Before processing the model output, process any aborts that happened
        # during the model execution.
        self._process_aborts_queue()
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        # NOTE(nick): We can either handle the deferred tasks here or save
        # in a field and do it immediately once step_with_batch_queue is
        # re-called. The latter slightly favors TTFT over TPOT/throughput.
        if deferred_scheduler_output is not None:
            # We now have the tokens needed to compute the bitmask for the
            # deferred request. Get the bitmask and call sample tokens.
            grammar_output = self.scheduler.get_grammar_bitmask(
                deferred_scheduler_output
            )
            future = self.model_executor.sample_tokens(grammar_output, non_block=True)
            batch_queue.appendleft((future, deferred_scheduler_output))

        return engine_core_outputs, model_executed

    def _process_aborts_queue(self):
        """Drain ``aborts_queue`` and abort every queued request id at once."""
        if not self.aborts_queue.empty():
            request_ids = []
            while not self.aborts_queue.empty():
                ids = self.aborts_queue.get_nowait()
                if isinstance(ids, str):
                    # Should be a list here, but also handle string just in case.
                    ids = (ids,)
                request_ids.extend(ids)
            # More efficient to abort all as a single batch.
            self.abort_requests(request_ids)

    def shutdown(self):
        """Release the structured-output backend, the executor and the scheduler.

        Attributes are looked up with ``getattr`` so this is also safe on an
        instance whose ``__init__`` raised part-way: e.g. if KV cache setup
        failed after the executor was built, the executor is still shut down
        instead of an ``AttributeError`` masking the original failure.
        """
        structured_output_manager = getattr(self, "structured_output_manager", None)
        if structured_output_manager is not None:
            structured_output_manager.clear_backend()

        model_executor = getattr(self, "model_executor", None)
        if model_executor:
            model_executor.shutdown()

        scheduler = getattr(self, "scheduler", None)
        if scheduler:
            scheduler.shutdown()

    def profile(self, is_start: bool = True):
        """Forward a profiler start (default) or stop request to the executor."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Clear the multi-modal receiver cache here and on the executor."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver)
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the multi-modal cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # The cache either exists in EngineCore or WorkerWrapperBase
        if self.mm_receiver_cache is not None:
            self.mm_receiver_cache.clear_cache()

        self.model_executor.reset_mm_cache()

    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Delegate a prefix-cache reset to the scheduler and return its result."""
        return self.scheduler.reset_prefix_cache(
            reset_running_requests, reset_connector
        )

    def sleep(self, level: int = 1):
        """Ask the executor to sleep at the given level."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: list[str] | None = None):
        """Ask the executor to wake up, optionally restricted to ``tags``."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` attribute."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run a dummy batch on the executor."""
        self.model_executor.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter via the executor; returns the executor's result."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter via the executor; returns the executor's result."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter via the executor; returns the executor's result."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Forward a sharded-state save to the executor (arguments passed as-is)."""
        self.model_executor.save_sharded_state(
            path=path, pattern=pattern, max_size=max_size
        )

    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Invoke ``method`` through ``model_executor.collective_rpc``."""
        return self.model_executor.collective_rpc(method, timeout, args, kwargs)

    def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Preprocess the request.

        This function could be directly used in input processing thread to allow
        request initialization running in parallel with Model forward

        Returns:
            The constructed ``Request`` and ``request.current_wave``.
        """
        # Note on thread safety: no race condition.
        # `mm_receiver_cache` is reset at the end of LLMEngine init,
        # and will only be accessed in the input processing thread afterwards.
        if self.mm_receiver_cache is not None and request.mm_features:
            request.mm_features = self.mm_receiver_cache.get_and_update_features(
                request.mm_features
            )

        req = Request.from_engine_core_request(request, self.request_block_hasher)
        if req.use_structured_output:
            # Note on thread safety: no race condition.
            # `grammar_init` is only invoked in input processing thread. For
            # `structured_output_manager`, each request is independent and
            # grammar compilation is async. Scheduler always checks grammar
            # compilation status before scheduling request.
            self.structured_output_manager.grammar_init(req)
        return req, request.current_wave


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process.

    A background input thread receives requests from the ZMQ input sockets
    and decodes them onto ``input_queue``. The main thread's busy loop
    consumes that queue and steps the engine. Results are placed on
    ``output_queue``, and a background output thread sends them out over
    the ZMQ output sockets.
    """

    # Sentinel put on ``output_queue``; the output thread forwards it to every
    # client output socket and then exits.
    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        engine_index: int = 0,
    ):
        """Handshake with the front-end, init the engine, start IO threads.

        Args:
            vllm_config: Engine configuration (may be updated by the
                handshake).
            local_client: Whether the client is colocated with this engine.
            handshake_address: ZMQ address used for the startup handshake.
            executor_class: Executor type passed to ``EngineCore.__init__``.
            log_stats: Passed through to ``EngineCore.__init__``.
            client_handshake_address: Optional address for a second handshake
                with a colocated front-end (external/hybrid DP LB).
            engine_index: Index of this engine; also its 2-byte ZMQ identity.

        Raises:
            RuntimeError: If the input socket thread dies during startup.
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]()
        executor_fail_callback = lambda: self.input_queue.put_nowait(
            (EngineCoreRequestType.EXECUTOR_FAILED, b"")
        )

        self.engine_index = engine_index
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        with self._perform_handshakes(
            handshake_address,
            identity,
            local_client,
            vllm_config,
            client_handshake_address,
        ) as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address
            )
            logger.debug(
                "Has DP Coordinator: %s, stats publish address: %s",
                self.has_coordinator,
                self.frontend_stats_publish_address,
            )
            # Only publish request queue stats to coordinator for "internal"
            # and "hybrid" LB modes.
            self.publish_dp_lb_stats = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb
            )

            self._init_data_parallel(vllm_config)

            super().__init__(
                vllm_config, executor_class, log_stats, executor_fail_callback
            )

            # Background Threads and Queues for IO. These enable us to
            # overlap ZMQ socket IO with GPU since they release the GIL,
            # and to overlap some serialization/deserialization with the
            # model forward pass.
            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
            ready_event = threading.Event()
            input_thread = threading.Thread(
                target=self.process_input_sockets,
                args=(
                    addresses.inputs,
                    addresses.coordinator_input,
                    identity,
                    ready_event,
                ),
                daemon=True,
            )
            input_thread.start()

            self.output_thread = threading.Thread(
                target=self.process_output_sockets,
                args=(
                    addresses.outputs,
                    addresses.coordinator_output,
                    self.engine_index,
                ),
                daemon=True,
            )
            self.output_thread.start()

            # Don't complete handshake until DP coordinator ready message is
            # received.
            while not ready_event.wait(timeout=10):
                if not input_thread.is_alive():
                    raise RuntimeError("Input socket thread died during startup")
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Perform startup handshakes.

        For DP=1 or offline mode, this is with the colocated front-end process.

        For DP>1 with internal load-balancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external or hybrid load-balancing, two handshakes are
        performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
        with the exception of the rank 0 and colocated engines themselves which
        don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
        server is not scaled out), OR the launcher process running the
        run_multi_api_server() function in serve.py.
        """
        input_ctx = zmq.Context()
        is_local = local_client and client_handshake_address is None
        headless = not local_client
        handshake = self._perform_handshake(
            input_ctx,
            handshake_address,
            identity,
            is_local,
            headless,
            vllm_config,
            vllm_config.parallel_config,
        )
        if client_handshake_address is None:
            with handshake as addresses:
                yield addresses
        else:
            assert local_client
            local_handshake = self._perform_handshake(
                input_ctx, client_handshake_address, identity, True, False, vllm_config
            )
            with handshake as addresses, local_handshake as client_addresses:
                addresses.inputs = client_addresses.inputs
                addresses.outputs = client_addresses.outputs
                yield addresses

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: ParallelConfig | None = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Register with one front-end, yield its addresses, then send READY.

        The READY message is sent only after the ``with`` body completes,
        i.e. after engine initialization inside that body has finished.
        """
        with make_zmq_socket(
            ctx,
            handshake_address,
            zmq.DEALER,
            identity=identity,
            linger=5000,
            bind=False,
        ) as handshake_socket:
            # Register engine with front-end.
            addresses = self.startup_handshake(
                handshake_socket, local_client, headless, parallel_config_to_update
            )
            yield addresses

            # Send ready message.
            num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
            # We pass back the coordinator stats update address here for the
            # external LB case for our colocated front-end to use (coordinator
            # only runs with rank 0).
            dp_stats_address = self.frontend_stats_publish_address

            # Include config hash for DP configuration validation
            ready_msg = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": num_gpu_blocks,
                "dp_stats_address": dp_stats_address,
            }
            if vllm_config.parallel_config.data_parallel_size > 1:
                ready_msg["parallel_config_hash"] = (
                    vllm_config.parallel_config.compute_hash()
                )

            handshake_socket.send(msgspec.msgpack.encode(ready_msg))

    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: ParallelConfig | None = None,
    ) -> EngineZmqAddresses:
        """Send HELLO, wait for the front-end's init message, return addresses.

        If ``parallel_config`` is given, every key/value in the init message's
        ``parallel_config`` mapping is set on it as an attribute.

        Raises:
            RuntimeError: If no reply arrives within HANDSHAKE_TIMEOUT_MINS.
        """
        # Send registration message.
        handshake_socket.send(
            msgspec.msgpack.encode(
                {
                    "status": "HELLO",
                    "local": local_client,
                    "headless": headless,
                }
            )
        )

        # Receive initialization message.
        logger.debug("Waiting for init message from front-end.")
        if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
            raise RuntimeError(
                "Did not receive response from front-end "
                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                f"minutes"
            )
        init_bytes = handshake_socket.recv()
        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            init_bytes, type=EngineHandshakeMetadata
        )
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for key, value in init_message.parallel_config.items():
                setattr(parallel_config, key, value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
        """Launch EngineCore busy loop in background process.

        Builds a ``DPEngineCoreProc`` when data parallelism is in use (or
        ``dp_rank > 0``), otherwise an ``EngineCoreProc``, then runs its busy
        loop until a signal or fatal error. ``shutdown()`` is always called on
        an engine that was successfully constructed.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: EngineCoreProc | None = None
        try:
            parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1 or dp_rank > 0:
                set_process_title("EngineCore", f"DP{dp_rank}")
                decorate_logs()
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                set_process_title("EngineCore")
                decorate_logs()
                engine_core = EngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            # Bare raise re-raises the active exception with its traceback.
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Hook for data-parallel setup; a no-op for the non-DP engine."""
        pass

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()
            # 2) Step the engine core and return the outputs.
            self._process_engine_step()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed.

        Blocks on ``input_queue`` while there is nothing to do. It then drains
        any further queued client requests without blocking.
        """

        waited = False
        while (
            not self.engines_running
            and not self.scheduler.has_requests()
            and not self.batch_queue
        ):
            if self.input_queue.empty():
                # Drain aborts queue; all aborts are also processed via input_queue.
                with self.aborts_queue.mutex:
                    self.aborts_queue.queue.clear()
                if logger.isEnabledFor(DEBUG):
                    logger.debug("EngineCore waiting for work.")
                    waited = True
            req = self.input_queue.get()
            self._handle_client_request(*req)

        if waited:
            logger.debug("EngineCore loop active.")

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

    def _process_engine_step(self) -> bool:
        """Run one engine step and enqueue its outputs.

        Called once per busy-loop iteration after ``_process_input_queue``
        returns.

        Returns:
            Whether the model was executed during this step.
        """

        # Step the engine core.
        outputs, model_executed = self.step_fn()
        # Put EngineCoreOutputs into the output queue.
        for output in outputs.items() if outputs else ():
            self.output_queue.put_nowait(output)
        # Post-step hook.
        self.post_step(model_executed)

        # If no model execution happened but there are waiting requests
        # (e.g., WAITING_FOR_REMOTE_KVS), yield the GIL briefly to allow
        # background threads (like NIXL handshake) to make progress.
        # Without this, the tight polling loop can starve background threads.
        if not model_executed and self.scheduler.has_unfinished_requests():
            time.sleep(0.001)

        return model_executed

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Dispatch request from client.

        Raises:
            RuntimeError: On an EXECUTOR_FAILED notification.
        """

        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                result = method(*self._convert_msgspec_args(method, args))
                output.result = UtilityResult(result)
            except Exception as e:
                # Deliberately not BaseException: the SIGTERM/SIGINT handler
                # in run_engine_core raises SystemExit only once, so
                # swallowing it here would leave the engine running and
                # deaf to further signals.
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (
                    f"Call to {method_name} method failed: {str(e)}"
                )
            self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=output))
            )
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error(
                "Unrecognized input request type encountered: %s", request_type
            )

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
        arg type, try converting to msgspec object.

        Raises:
            TypeError: If more positional args are given than ``method``
                declares parameters. This is an explicit check rather than an
                ``assert``, so that it still applies under ``python -O``.
        """
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        if len(args) > len(arg_types):
            raise TypeError(
                f"{getattr(method, '__name__', method)} got {len(args)} "
                f"positional args but declares only {len(arg_types)} parameters"
            )
        return tuple(
            msgspec.convert(v, type=p.annotation)
            if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation)
            else v
            for v, p in zip(args, arg_types)
        )

    def _send_engine_dead(self):
        """Send EngineDead status to the EngineCoreClient."""

        # Put ENGINE_CORE_DEAD in the queue.
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Wait until msg sent by the daemon before shutdown.
        self.output_thread.join(timeout=5.0)
        if self.output_thread.is_alive():
            logger.fatal(
                "vLLM shutdown signal from EngineCore failed "
                "to send. Please report this issue."
            )

    def process_input_sockets(
        self,
        input_addresses: list[str],
        coord_input_address: str | None,
        identity: bytes,
        ready_event: threading.Event,
    ):
        """Input socket IO thread.

        Connects to the client input sockets and, if an address is given, the
        DP coordinator. It sets ``ready_event`` once connected, then decodes
        incoming requests onto ``input_queue`` forever.

        Raises:
            RuntimeError: If the coordinator's first message is not READY.
        """

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(
                        ctx, input_address, zmq.DEALER, identity=identity, bind=False
                    )
                )
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx,
                        coord_input_address,
                        zmq.XSUB,
                        identity=identity,
                        bind=False,
                    )
                )
                # Send subscription message to coordinator.
                coord_socket.send(b"\x01")

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b"")
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator. The recv() must not
                # live inside an `assert`: under `python -O` it would be
                # stripped, and READY would later be misread as a request.
                coord_msg = coord_socket.recv()
                if coord_msg != b"READY":
                    raise RuntimeError(
                        f"Unexpected first message from DP Coordinator: {coord_msg!r}"
                    )
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            del ready_event
            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(copy=False)
                    request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                    # Deserialize the request data.
                    request: Any
                    if request_type == EngineCoreRequestType.ADD:
                        req: EngineCoreRequest = add_request_decoder.decode(data_frames)
                        try:
                            request = self.preprocess_add_request(req)
                        except Exception:
                            self._handle_request_preproc_error(req)
                            continue
                    else:
                        request = generic_decoder.decode(data_frames)

                        if request_type == EngineCoreRequestType.ABORT:
                            # Aborts are added to *both* queues, allows us to eagerly
                            # process aborts while also ensuring ordering in the input
                            # queue to avoid leaking requests. This is ok because
                            # aborting in the scheduler is idempotent.
                            self.aborts_queue.put_nowait(request)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

    def process_output_sockets(
        self,
        output_paths: list[str],
        coord_output_path: str | None,
        engine_index: int,
    ):
        """Output socket IO thread.

        Sends items from ``output_queue`` to the client selected by their
        client index (-1 means the DP coordinator). It exits after forwarding
        ``ENGINE_CORE_DEAD`` to every client socket.
        """

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)
                )
                for output_path in output_paths
            ]
            coord_socket = (
                stack.enter_context(
                    make_zmq_socket(
                        ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000
                    )
                )
                if coord_output_path is not None
                else None
            )
            max_reuse_bufs = len(sockets) + 1

            while True:
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    for socket in sockets:
                        socket.send(output)
                    break
                assert not isinstance(output, bytes)
                client_index, outputs = output
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Reclaim buffers that zmq is finished with. Apply the same
                # max_reuse_bufs cap as below, so that a burst of pending
                # sends can't inflate the reuse pool; extra buffers are
                # simply dropped.
                while pending and pending[-1][0].done:
                    _, _, done_buffer = pending.pop()
                    if len(reuse_buffers) < max_reuse_bufs:
                        reuse_buffers.append(done_buffer)

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = sockets[client_index].send_multipart(
                    buffers, copy=False, track=True
                )
                if not tracker.done:
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
                elif len(reuse_buffers) < max_reuse_bufs:
                    # Limit the number of buffers to reuse.
                    reuse_buffers.append(buffer)

    def _handle_request_preproc_error(self, request: EngineCoreRequest) -> None:
        """Log and return a request-scoped error response for exceptions raised
        from the add request preprocessing in the input socket processing thread.
        """
        logger.exception(
            "Unexpected error pre-processing request %s", request.request_id
        )
        self.output_queue.put_nowait(
            (
                request.client_index,
                EngineCoreOutputs(
                    engine_index=self.engine_index,
                    finished_requests={request.request_id},
                    outputs=[
                        EngineCoreOutput(
                            request_id=request.request_id,
                            new_token_ids=[],
                            finish_reason=FinishReason.ERROR,
                        )
                    ],
                ),
            )
        )


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
    ):
        """Set up DP wave bookkeeping, then initialize as a regular engine.

        The engine index passed to the parent is this process's
        ``data_parallel_rank``. The counters are initialized before the
        parent constructor, which performs the handshake and starts the IO
        threads.
        """
        # Number of model forward passes; used to sync "finished" status with
        # DP peers only every N steps.
        self.step_counter = 0
        self.current_wave = 0
        self.last_counts = (0, 0)

        super().__init__(
            vllm_config=vllm_config,
            local_client=local_client,
            handshake_address=handshake_address,
            executor_class=executor_class,
            log_stats=log_stats,
            client_handshake_address=client_handshake_address,
            engine_index=vllm_config.parallel_config.data_parallel_rank,
        )


    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Validate DP ranks, make the KV-transfer engine id rank-unique and
        create the stateless DP process group."""
        par_cfg = vllm_config.parallel_config
        dp_rank = par_cfg.data_parallel_rank
        dp_size = par_cfg.data_parallel_size
        local_dp_rank = par_cfg.data_parallel_rank_local

        assert dp_size > 1
        assert local_dp_rank is not None
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        kv_cfg = vllm_config.kv_transfer_config
        if kv_cfg is not None:
            # Suffix the local DP rank so each DP rank gets a distinct
            # kv_transfer_config engine_id.
            kv_cfg.engine_id = f"{kv_cfg.engine_id}_dp{local_dp_rank}"
            logger.debug(
                "Setting kv_transfer_config.engine_id to %s",
                kv_cfg.engine_id,
            )

        self.dp_rank = dp_rank
        self.dp_group = par_cfg.stateless_init_dp_group()

    def shutdown(self):
        """Shut down the engine, then destroy the DP process group if one
        was created."""
        super().shutdown()
        group = getattr(self, "dp_group", None)
        if group:
            stateless_destroy_torch_distributed_process_group(group)

    def add_request(self, request: Request, request_wave: int = 0):
        """Track the request's wave (coordinator mode only), then add it.

        A newer wave advances ``current_wave``. An older wave arriving while
        engines are paused makes the engine ask the coordinator (client index
        -1) to start ``current_wave``.
        """
        if self.has_coordinator:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif request_wave < self.current_wave and not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                start_msg = EngineCoreOutputs(start_wave=self.current_wave)
                self.output_queue.put_nowait((-1, start_msg))

        super().add_request(request, request_wave)

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Handle START_DP_WAVE here; delegate every other type to the parent.

        A START_DP_WAVE for a wave at or after ``current_wave`` adopts that
        wave and marks engines as running. It is ignored when it was
        triggered by this engine (``exclude_eng_index``).
        """
        if request_type != EngineCoreRequestType.START_DP_WAVE:
            super()._handle_client_request(request_type, request)
            return

        new_wave, exclude_eng_index = request
        if exclude_eng_index == self.engine_index or new_wave < self.current_wave:
            return

        self.current_wave = new_wave
        if not self.engines_running:
            logger.debug("EngineCore starting idle loop for wave %d.", new_wave)
            self.engines_running = True

    def _maybe_publish_request_counts(self):
        """Send scheduler request counts to the coordinator when they change.

        No-op unless ``publish_dp_lb_stats`` is set.
        """
        if not self.publish_dp_lb_stats:
            return

        counts = self.scheduler.get_request_counts()
        if counts == self.last_counts:
            return

        self.last_counts = counts
        stats = SchedulerStats(
            *counts, step_counter=self.step_counter, current_wave=self.current_wave
        )
        self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case.

        Each iteration handles input and steps the engine. It runs a dummy
        batch if the engine is in a running wave but executed nothing, then
        syncs "unfinished" status across DP ranks. When all ranks are done,
        it reports the completed wave and advances to the next one.
        Runs until the process receives SIGINT or SIGTERM.
        """
        while True:
            # 1) Block on the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            model_ran = self._process_engine_step()
            self._maybe_publish_request_counts()

            has_local_work = self.scheduler.has_unfinished_requests()
            if not (model_ran or has_local_work or self.engines_running):
                # All engines are idle.
                continue

            if not model_ran:
                # In a running wave every rank must take part in each
                # forward pass, so run a dummy batch when nothing real ran.
                self.execute_dummy_batch()

            # 3) All-reduce to determine whether any DP rank has unfinished
            # requests.
            self.engines_running = self._has_global_unfinished_reqs(has_local_work)
            if self.engines_running:
                continue

            # The wave is over. With a coordinator only DP rank 0 reports it
            # (to the coordinator, client index -1); without one (offline SPMD)
            # each rank reports to its colocated front-end (client index 0).
            if not self.has_coordinator:
                report_to = 0
            elif self.dp_rank == 0:
                report_to = -1
            else:
                report_to = None

            if report_to is not None:
                logger.debug(
                    "Wave %d finished, pausing engine loop.", self.current_wave
                )
                done_msg = EngineCoreOutputs(wave_complete=self.current_wave)
                self.output_queue.put_nowait((report_to, done_msg))

            # Move to the next wave and reset the step counter.
            self.current_wave += 1
            self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
1336
        # Optimization - only perform finish-sync all-reduce every 32 steps.
1337
1338
        self.step_counter += 1
        if self.step_counter % 32 != 0:
1339
1340
            return True

1341
        return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)
Rui Qiao's avatar
Rui Qiao committed
1342

1343
    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Rebuild this engine core's data-parallel state for a new DP layout.

        Steps, in order:

        1. Destroy the current DP process group and call ``self.shutdown()``.
        2. Apply the new DP size, rank and master IP/port from
           ``reconfig_request`` to ``self.vllm_config.parallel_config``.
        3. Unless this rank is being shut down, re-create ``self.dp_rank``
           and ``self.dp_group`` from the updated config.
        4. Forward the request to the model executor.
        5. If the DP size grew, sync ``available_gpu_memory_for_kv_cache``
           across the new group and run ``compile_or_warm_up_model`` on the
           workers.
        6. If the rank is marked ``SHUTDOWN_CURRENT_RANK``, call
           ``self.shutdown()`` again.

        Args:
            reconfig_request: The new DP configuration. Its
                ``new_data_parallel_master_port`` is overwritten in place with
                the port held by the parallel config after the DP group is
                re-initialized. ``new_data_parallel_rank_local`` must be
                ``ReconfigureRankType.KEEP_CURRENT_RANK``.
        """
        stateless_destroy_torch_distributed_process_group(self.dp_group)
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        # KEEP_CURRENT_RANK means "leave data_parallel_rank unchanged".
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert (
            reconfig_request.new_data_parallel_rank_local
            == ReconfigureRankType.KEEP_CURRENT_RANK
        )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        # A rank that is being shut down does not join the new DP group.
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        # NOTE(review): stateless_init_dp_group() presumably may advance
        # data_parallel_master_port; the current value is written back so the
        # executor below sees the same port. Confirm against ParallelConfig.
        reconfig_request.new_data_parallel_master_port = (
            parallel_config.data_parallel_master_port
        )

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache
            )
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info(
                "Distributed environment reinitialized for DP rank %s", self.dp_rank
            )
1394

Rui Qiao's avatar
Rui Qiao committed
1395
1396
1397
1398
1399
1400
1401
1402
1403

class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context.

    All ZMQ addresses are supplied at construction time, so the startup
    handshake is replaced by simply yielding those addresses
    (see ``_perform_handshakes``).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        """Record DP ranks, pin visible devices, then initialize the base proc.

        Args:
            vllm_config: Engine config. Its ``parallel_config`` is mutated in
                place with ``dp_rank`` and ``local_dp_rank``.
            local_client: Forwarded unchanged to ``DPEngineCoreProc``.
            addresses: Pre-computed ZMQ addresses, later yielded by
                ``_perform_handshakes``.
            executor_class: Executor type forwarded to the base class.
            log_stats: Forwarded to the base class.
            dp_rank: Data-parallel rank (stored as ``data_parallel_rank``).
            local_dp_rank: Local data-parallel rank (stored as
                ``data_parallel_rank_local``); also selects which devices
                this actor may see.
        """
        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_rank = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_visible_devices(vllm_config, local_dp_rank)

        # The empty string is passed in the base class's third positional slot
        # (presumably the handshake address). No real handshake happens
        # because _perform_handshakes is overridden below.
        super().__init__(vllm_config, local_client, "", executor_class, log_stats)

    def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int):
        """Restrict this actor's visible devices to its local DP rank's range.

        XPU is left untouched. On every other platform, the platform's
        device-control env var (e.g. ``CUDA_VISIBLE_DEVICES``) is set via
        ``_set_cuda_visible_devices``.
        """
        from vllm.platforms import current_platform

        if not current_platform.is_xpu():
            self._set_cuda_visible_devices(
                vllm_config, local_dp_rank, current_platform.device_control_env_var
            )

    def _set_cuda_visible_devices(
        self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str
    ):
        """Set ``device_control_env_var`` to the devices of this local DP rank.

        The devices come from ``get_device_indices`` for the range
        ``[local_dp_rank * world_size, (local_dp_rank + 1) * world_size)``.

        Raises:
            RuntimeError: If ``get_device_indices`` raises ``IndexError``
                (presumably the env var's base value lists too few devices
                for that range). The original error is chained as the cause.
        """
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent. Only the lookup can fail,
        # so only the lookup is guarded.
        try:
            value = get_device_indices(
                device_control_env_var, local_dp_rank, world_size
            )
        except IndexError as e:
            # RuntimeError is a subclass of Exception, so existing
            # `except Exception` handlers still catch this.
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f'base value: "{os.getenv(device_control_env_var)}"'
            ) from e
        os.environ[device_control_env_var] = value

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.

        All parameters are accepted only for interface compatibility with the
        base class and are ignored.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """

    def run(self):
        """
        Run the engine core busy loop.

        ``SystemExit`` is logged at debug level and re-raised. Any other
        exception is logged with its traceback and re-raised.
        ``self.shutdown()`` is always called on the way out.
        """
        try:
            self.run_busy_loop()
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()