core.py 32.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
import os
3
import queue
4
import signal
5
import sys
6
7
import threading
import time
8
from collections import deque
9
from concurrent.futures import Future
10
from inspect import isclass, signature
11
from logging import DEBUG
12
from typing import Any, Callable, Optional, TypeVar, Union
13

14
import msgspec
15
16
import zmq

17
18
19
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
20
from vllm.logger import init_logger
21
from vllm.logging_utils.dump_input import dump_engine_exception
22
from vllm.lora.request import LoRARequest
23
24
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
25
from vllm.utils import make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx
26
27
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                         unify_kv_cache_configs)
28
from vllm.v1.core.sched.interface import SchedulerInterface
29
30
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
31
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
32
                            EngineCoreRequestType, UtilityOutput)
33
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
34
from vllm.v1.executor.abstract import Executor
35
from vllm.v1.kv_cache_interface import KVCacheConfig
36
from vllm.v1.outputs import ModelRunnerOutput
37
from vllm.v1.request import Request, RequestStatus
38
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
39
from vllm.v1.structured_output import StructuredOutputManager
40
41
42
43
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# Polling timeout in seconds (not referenced in the visible part of this
# module; kept for other users).
POLLING_TIMEOUT_S = 2.5

# How long EngineCoreProc.startup_handshake waits for the front-end's init
# message before raising.
HANDSHAKE_TIMEOUT_MINS = 5

# Return type for collective_rpc.
_R = TypeVar('_R')

49
50
51
52

class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the scheduler, the structured-output manager
    and the multi-modal input cache, and advances them one step at a time
    via :meth:`step` (or :meth:`step_with_batch_queue` when the executor
    supports more than one concurrent batch).
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 executor_class: type[Executor],
                 log_stats: bool,
                 executor_fail_callback: Optional[Callable] = None):
        """Build the executor, size the KV cache and create the scheduler.

        Args:
            vllm_config: Engine configuration. ``cache_config.num_gpu_blocks``
                and ``cache_config.num_cpu_blocks`` are overwritten with the
                values determined while initializing the KV caches.
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Stored on the engine and forwarded to the scheduler.
            executor_fail_callback: If given, registered with the executor
                via ``register_failure_callback``.
        """
        assert vllm_config.model_config.runner_type != "pooling"

        self.vllm_config = vllm_config
        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(
                executor_fail_callback)

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
            self._initialize_kv_caches(vllm_config)

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler. The configured class may be given either as a
        # class object or as a fully-qualified name string.
        scheduler_cls = vllm_config.scheduler_config.scheduler_cls
        if isinstance(scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(scheduler_cls)
        else:
            Scheduler = scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.", scheduler_cls)

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            # Enabled only when running with more than one DP rank.
            include_finished_set=(
                vllm_config.parallel_config.data_parallel_size > 1),
            log_stats=self.log_stats,
        )

        # Setup MM Input Mapper.
        self.mm_input_cache_server = MirroredProcessingCache(
            vllm_config.model_config)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
                                                     SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

    def _initialize_kv_caches(
            self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
        """Profile memory, build per-worker KV cache configs and allocate them.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, kv_cache_config)``.
            ``num_cpu_blocks`` is always 0. ``kv_cache_config`` is the first
            worker's (unified) config and is used to build the scheduler.
        """
        # perf_counter is monotonic, so the measured duration is not affected
        # by wall-clock adjustments.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        # Profiles the peak memory usage of the model to determine how much
        # memory can be allocated for kv cache.
        available_gpu_memory = self.model_executor.determine_available_memory()

        # One spec and one memory figure per worker.
        assert len(kv_cache_specs) == len(available_gpu_memory)
        # Get the kv cache tensor size
        kv_cache_configs = [
            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
                                available_gpu_memory_one_worker)
            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
            zip(kv_cache_specs, available_gpu_memory)
        ]

        # Since we use a shared centralized controller, we need the
        # `kv_cache_config` to be consistent across all workers to make sure
        # all the memory operators can be applied to all workers.
        unify_kv_cache_configs(kv_cache_configs)

        # All workers have the same kv_cache_config except layer names, so use
        # an arbitrary one to initialize the scheduler.
        assert all(cfg.num_blocks == kv_cache_configs[0].num_blocks
                   for cfg in kv_cache_configs)
        num_gpu_blocks = kv_cache_configs[0].num_blocks
        num_cpu_blocks = 0
        scheduler_kv_cache_config = kv_cache_configs[0]

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler.

        Resolves multi-modal inputs through the mirrored cache, starts grammar
        compilation for structured-output requests, and warns when
        ``kv_transfer_params`` is set but the scheduler has no KV connector.
        """
        if request.mm_hashes is not None:
            # Here, if hash exists for a multimodal input, then it will be
            # fetched from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client cache, so
            # anything that has a hash must have a HIT cache entry here
            # as well.
            assert request.mm_inputs is not None
            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
                request.mm_inputs, request.mm_hashes)

        req = Request.from_engine_core_request(request)
        if req.use_structured_output:
            # Start grammar compilation asynchronously
            self.structured_output_manager.grammar_init(req)

        if req.kv_transfer_params is not None and (
                not self.scheduler.get_kv_connector()):
            logger.warning("Got kv_transfer_params, but no KVConnector found. "
                           "Disabling KVTransfer for this request.")

        self.scheduler.add_request(req)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def execute_model(self, scheduler_output: SchedulerOutput):
        """Run ``scheduler_output`` on the executor and return its result.

        On any failure, the scheduler output and current scheduler stats are
        passed to ``dump_engine_exception`` before the original exception is
        re-raised unchanged.
        """
        try:
            return self.model_executor.execute_model(scheduler_output)
        except BaseException:
            # NOTE: dump_engine_exception is intended to be exception-free.
            dump_engine_exception(self.vllm_config, scheduler_output,
                                  self.scheduler.make_stats())
            # Bare raise re-raises the active exception as-is.
            raise

    def step(self) -> EngineCoreOutputs:
        """Schedule, execute, and make output."""

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return EngineCoreOutputs(
                outputs=[],
                scheduler_stats=self.scheduler.make_stats(),
            )
        scheduler_output = self.scheduler.schedule()
        model_output = self.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output)  # type: ignore

        return engine_core_outputs

    def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, fulfilling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None
        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        if not self.batch_queue.full():
            scheduler_output = self.scheduler.schedule()
            if scheduler_output.total_num_scheduled_tokens > 0:
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        scheduled_batch = (scheduler_output is not None
                           and scheduler_output.total_num_scheduled_tokens > 0)

        # If no more requests can be scheduled and the job queue is not empty,
        # block until the first batch in the job queue is finished.
        # TODO(comaniac): Ideally we should peek the first batch in the
        # job queue to check if it's finished before scheduling a new batch,
        # but peeking the first element in a queue is not thread-safe,
        # so we need more work.
        if not scheduled_batch and not self.batch_queue.empty():
            future, scheduler_output = self.batch_queue.get_nowait()
            # Blocking until the first result is available.
            model_output = future.result()
            self.batch_queue.task_done()
            engine_core_outputs = self.scheduler.update_from_output(
                scheduler_output, model_output)

        return engine_core_outputs

    def shutdown(self):
        """Clear the structured-output backend, then shut down the executor
        and scheduler if they are set."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop profiling on the executor."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Reset the engine-side multi-modal input cache."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
        if self.scheduler.has_unfinished_requests():
            logger.warning("Resetting the multi-modal cache when requests are "
                           "in progress may lead to desynced internal caches.")

        self.mm_input_cache_server.reset()

    def reset_prefix_cache(self):
        """Delegate prefix-cache reset to the scheduler."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Put the executor to sleep at the given ``level``."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """Wake the executor up, optionally limited to ``tags``."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` attribute."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run ``execute_dummy_batch`` on all workers via collective RPC."""
        self.model_executor.collective_rpc("execute_dummy_batch")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate to the executor's ``add_lora``."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate to the executor's ``remove_lora``."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Delegate to the executor's ``list_loras``."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate to the executor's ``pin_lora``."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Delegate to the executor's ``save_sharded_state``."""
        self.model_executor.save_sharded_state(path=path,
                                               pattern=pattern,
                                               max_size=max_size)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Forward a collective RPC to the executor and return its results."""
        return self.model_executor.collective_rpc(method, timeout, args,
                                                  kwargs)

339
340
341
342

class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

    # Sentinel put on the output queue to tell the output thread to send a
    # final "engine dead" message and exit.
    ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD'

    def __init__(
        self,
        vllm_config: VllmConfig,
        on_head_node: bool,
        input_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        engine_index: int = 0,
    ):
        """Handshake with the front-end, build the engine, start IO threads.

        Args:
            vllm_config: Engine configuration. Its ``parallel_config`` may be
                updated from the front-end's init message.
            on_head_node: Reported to the front-end as ``"local"``.
            input_address: Address the DEALER input socket connects to.
            executor_class: Executor type passed to :class:`EngineCore`.
            log_stats: Passed to :class:`EngineCore`.
            engine_index: Used as the 2-byte input socket identity and stamped
                on every ``EngineCoreOutputs`` sent on the output socket.
        """
        input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()

        # Executor failures are injected into the input queue so that the
        # busy loop raises on them (see _handle_client_request).
        executor_fail_callback = lambda: input_queue.put_nowait(
            (EngineCoreRequestType.EXECUTOR_FAILED, b''))

        # Create input socket.
        # NOTE(review): input_ctx is never terminated here, even on failure —
        # confirm this is intended.
        input_ctx = zmq.Context()
        identity = engine_index.to_bytes(length=2, byteorder="little")
        input_socket = make_zmq_socket(input_ctx,
                                       input_address,
                                       zmq.DEALER,
                                       identity=identity,
                                       bind=False)
        try:
            # Register engine with front-end.
            output_address = self.startup_handshake(
                input_socket, on_head_node, vllm_config.parallel_config)

            # Update config which may have changed from the handshake.
            vllm_config.__post_init__()

            # Set up data parallel environment.
            self._init_data_parallel(vllm_config)

            # Initialize engine core and model.
            super().__init__(vllm_config, executor_class, log_stats,
                             executor_fail_callback)

            self.step_fn = (self.step if self.batch_queue is None else
                            self.step_with_batch_queue)
            self.engines_running = False

            # Send ready message.
            num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
            input_socket.send(
                msgspec.msgpack.encode({
                    "status": "READY",
                    "local": on_head_node,
                    "num_gpu_blocks": num_gpu_blocks,
                }))

            # Background Threads and Queues for IO. These enable us to
            # overlap ZMQ socket IO with GPU since they release the GIL,
            # and to overlap some serialization/deserialization with the
            # model forward pass.
            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
            self.input_queue = input_queue
            self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
            threading.Thread(target=self.process_input_socket,
                             args=(input_socket, ),
                             daemon=True).start()
            # Ownership of the socket has passed to the input thread; clear
            # the local so the finally-block below does not close it.
            input_socket = None
            self.output_thread = threading.Thread(
                target=self.process_output_socket,
                args=(output_address, engine_index),
                daemon=True)
            self.output_thread.start()
        finally:
            if input_socket is not None:
                input_socket.close(linger=0)

    @staticmethod
    def startup_handshake(input_socket: zmq.Socket, on_head_node: bool,
                          parallel_config: ParallelConfig) -> str:
        """Register with the front-end and receive the init message.

        Sends a ``HELLO`` message, waits up to ``HANDSHAKE_TIMEOUT_MINS``
        minutes for the reply, copies every entry of the reply's
        ``parallel_config`` mapping onto ``parallel_config`` and returns the
        reply's ``output_socket_address``.

        Raises:
            RuntimeError: If no reply arrives within the timeout.
        """
        # Send registration message.
        input_socket.send(
            msgspec.msgpack.encode({
                "status": "HELLO",
                "local": on_head_node,
            }))

        # Receive initialization message.
        logger.info("Waiting for init message from front-end.")
        # poll() takes its timeout in milliseconds.
        if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000):
            raise RuntimeError("Did not receive response from front-end "
                               f"process within {HANDSHAKE_TIMEOUT_MINS} "
                               "minutes")
        init_bytes = input_socket.recv()
        init_message = msgspec.msgpack.decode(init_bytes)
        logger.debug("Received init message: %s", init_message)

        output_socket_address = init_message["output_socket_address"]
        #TBD(nick) maybe replace IP with configured head node address

        received_parallel_config = init_message["parallel_config"]
        for key, value in received_parallel_config.items():
            setattr(parallel_config, key, value)

        return output_socket_address

    @staticmethod
    def run_engine_core(*args,
                        dp_rank: int = 0,
                        local_dp_rank: int = 0,
                        **kwargs):
        """Launch EngineCore busy loop in background process.

        Builds a :class:`DPEngineCoreProc` when data parallelism is in use
        (``data_parallel_size > 1`` or ``dp_rank > 0``), otherwise an
        :class:`EngineCoreProc`, and runs its busy loop. On a non-exit
        error after construction, the engine-dead message is sent before
        re-raising. The engine is shut down in all cases once constructed.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: Optional[EngineCoreProc] = None
        try:
            parallel_config: ParallelConfig = kwargs[
                "vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1 or dp_rank > 0:
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                engine_core = EngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Hook for data-parallel setup; a no-op for a single engine."""
        pass

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()
            # 2) Step the engine core and return the outputs.
            self._process_engine_step()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed."""

        waited = False
        # Block on the input queue while there is nothing to step.
        while not self.engines_running and not self.scheduler.has_requests():
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                waited = True
            req = self.input_queue.get()
            self._handle_client_request(*req)

        if waited:
            logger.debug("EngineCore loop active.")

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

    def _process_engine_step(self):
        """Called only when there are unfinished local requests."""

        # Step the engine core.
        outputs = self.step_fn()
        # Put EngineCoreOutputs into the output queue.
        if outputs is not None:
            self.output_queue.put_nowait(outputs)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Dispatch request from client.

        Raises:
            RuntimeError: On an ``EXECUTOR_FAILED`` request.
        """

        if request_type == EngineCoreRequestType.ADD:
            self.add_request(request)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                output.result = method(
                    *self._convert_msgspec_args(method, args))
            except Exception as e:
                # Report ordinary failures back to the caller. Non-Exception
                # BaseExceptions (notably the one-shot SystemExit raised by
                # the shutdown signal handler) must propagate, otherwise the
                # shutdown request would be silently lost.
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (f"Call to {method_name} method"
                                          f" failed: {str(e)}")
            self.output_queue.put_nowait(
                EngineCoreOutputs(utility_output=output))
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error("Unrecognized input request type encountered: %s",
                         request_type)

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
         arg type, try converting to msgspec object."""
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
            msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation) else v
            for v, p in zip(args, arg_types))

    def _send_engine_dead(self):
        """Send EngineDead status to the EngineCoreClient."""

        # Put ENGINE_CORE_DEAD in the queue.
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Wait until msg sent by the daemon before shutdown.
        self.output_thread.join(timeout=5.0)
        if self.output_thread.is_alive():
            logger.fatal("vLLM shutdown signal from EngineCore failed "
                         "to send. Please report this issue.")

    def process_input_socket(self, input_socket: zmq.Socket):
        """Input socket IO thread: decode frames and feed the input queue."""

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        while True:
            # (RequestType, RequestData)
            type_frame, *data_frames = input_socket.recv_multipart(copy=False)
            request_type = EngineCoreRequestType(bytes(type_frame.buffer))

            # Deserialize the request data.
            decoder = add_request_decoder if (
                request_type == EngineCoreRequestType.ADD) else generic_decoder
            request = decoder.decode(data_frames)

            # Push to input queue for core busy loop.
            self.input_queue.put_nowait((request_type, request))

    def process_output_socket(self, output_path: str, engine_index: int):
        """Output socket IO thread: encode queued outputs and push them."""

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with zmq_socket_ctx(output_path, zmq.constants.PUSH,
                            linger=4000) as socket:
            while True:
                outputs = self.output_queue.get()
                if outputs == EngineCoreProc.ENGINE_CORE_DEAD:
                    socket.send(outputs, copy=False)
                    break
                assert not isinstance(outputs, bytes)
                outputs.engine_index = engine_index

                # Reclaim buffers that zmq is finished with. Entries are
                # appended on the left, so the oldest send is at the right.
                while pending and pending[-1][0].done:
                    reuse_buffers.append(pending.pop()[2])

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = socket.send_multipart(buffers,
                                                copy=False,
                                                track=True)
                if not tracker.done:
                    # Only keep `outputs` alive when extra zero-copy frames
                    # beyond the main buffer were sent.
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
                elif len(reuse_buffers) < 2:
                    # Keep at most 2 buffers to reuse.
                    reuse_buffers.append(buffer)
651
652
653
654
655
656
657
658
659


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        on_head_node: bool,
        input_address: str,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        """Create an engine core process that participates in data parallel.

        Args:
            vllm_config: Full engine configuration; its
                ``parallel_config.data_parallel_rank`` is forwarded to the
                base class as this engine's DP rank.
            on_head_node: Forwarded unchanged to ``EngineCoreProc``.
            input_address: Forwarded unchanged to ``EngineCoreProc``.
            executor_class: Forwarded unchanged to ``EngineCoreProc``.
            log_stats: Forwarded unchanged to ``EngineCoreProc``.
        """
        # Tag everything this process writes to stdout/stderr with its
        # process name and pid; done before the engine is initialized.
        from multiprocessing import current_process
        proc_name, proc_pid = current_process().name, os.getpid()
        for stream in (sys.stdout, sys.stderr):
            _add_prefix(stream, proc_name, proc_pid)

        # Step counter read by _has_global_unfinished_reqs so that the
        # "finished" sync with DP peers only happens every N steps.
        # Initialized before the base constructor runs.
        self.counter = 0

        super().__init__(vllm_config, on_head_node, input_address,
                         executor_class, log_stats,
                         vllm_config.parallel_config.data_parallel_rank)

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Set device visibility and create the DP process group.

        The devices exposed to this process (through the platform's
        device-control environment variable) are the contiguous slice
        ``[local_dp_rank * world_size, (local_dp_rank + 1) * world_size)``,
        translated to physical device ids. Also initializes
        ``self.local_dp_rank``, ``self.dp_group`` and ``self.current_wave``.
        """
        parallel_config = vllm_config.parallel_config
        dp_rank = parallel_config.data_parallel_rank
        dp_size = parallel_config.data_parallel_size
        local_dp_rank = parallel_config.data_parallel_rank_local

        # Sanity checks on the DP topology.
        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        from vllm.platforms import current_platform
        world_size = parallel_config.world_size
        first_device = local_dp_rank * world_size
        visible_devices = [
            str(current_platform.device_id_to_physical_device_id(device))
            for device in range(first_device, first_device + world_size)
        ]
        os.environ[current_platform.device_control_env_var] = ",".join(
            visible_devices)

        self.local_dp_rank = local_dp_rank
        self.dp_group = parallel_config.stateless_init_dp_group()
        self.current_wave = 0

    def shutdown(self):
        """Shut down the engine core, then destroy the DP process group.

        The group is torn down in a ``finally`` clause so it is released even
        if the base-class shutdown raises. The attribute is cleared before
        destruction so a repeated ``shutdown()`` call does not attempt to
        destroy the same group a second time.
        """
        try:
            super().shutdown()
        finally:
            # ``dp_group`` is only assigned in ``_init_data_parallel``, so it
            # may be absent (e.g. if construction failed earlier).
            dp_group = getattr(self, "dp_group", None)
            if dp_group is not None:
                self.dp_group = None
                stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: EngineCoreRequest):
        """Reconcile *request* with the current DP wave, then add it.

        A request tagged with a later wave advances ``self.current_wave``.
        A request tagged with an earlier (already-completed) wave that
        arrives while ``engines_running`` is False causes an output with
        ``start_wave=self.current_wave`` to be queued for the front-end.
        The request is always handed to the base implementation afterwards.
        """
        request_wave = request.current_wave
        if request_wave > self.current_wave:
            self.current_wave = request_wave
        elif request_wave < self.current_wave and not self.engines_running:
            # Stale wave while idle: tell the front-end to start the next
            # wave.
            self.output_queue.put_nowait(
                EngineCoreOutputs(start_wave=self.current_wave))

        super().add_request(request)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Handle START_DP_WAVE locally; delegate every other request type.

        For START_DP_WAVE, *request* is the wave number. A wave number older
        than ``self.current_wave`` is ignored; otherwise the current wave is
        updated and, if the loop was idle, ``engines_running`` is set.
        """
        if request_type != EngineCoreRequestType.START_DP_WAVE:
            super()._handle_client_request(request_type, request)
            return

        new_wave: int = request
        if new_wave < self.current_wave:
            # Notification for a wave we have already moved past.
            return

        self.current_wave = new_wave
        if self.engines_running:
            return

        logger.debug("EngineCore starting idle loop for wave %d.", new_wave)
        self.engines_running = True

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case.

        Runs until the process is sent a SIGINT or SIGTERM. Each iteration
        polls the input queue, then either steps the local scheduler or,
        when this rank has no unfinished requests but the DP wave is still
        running, executes a dummy batch. The (periodic) global check in
        ``_has_global_unfinished_reqs`` then decides ``engines_running``;
        when it turns False the wave is closed and ``current_wave`` advances.
        """
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            has_local_work = self.scheduler.has_unfinished_requests()
            if has_local_work:
                # 2) Step the engine, then re-check whether that step
                # finished every local request.
                self._process_engine_step()
                has_local_work = self.scheduler.has_unfinished_requests()
            else:
                # Nothing unfinished locally, but finished requests may
                # still sit in the batch state; this step does no forward
                # pass and only flushes them so engine outputs are current.
                if self.scheduler.has_finished_requests():
                    self._process_engine_step()

                # Every engine is idle: skip the DP synchronization.
                if not self.engines_running:
                    continue

                # Some DP peer still has unfinished requests; run a dummy
                # forward pass alongside it.
                self.execute_dummy_batch()

            # 3) Determine whether any DP rank still has unfinished requests.
            self.engines_running = self._has_global_unfinished_reqs(
                has_local_work)
            if self.engines_running:
                continue

            # The wave has finished. Only local DP rank 0 notifies the client.
            if self.local_dp_rank == 0:
                logger.debug("Wave %d finished, pausing engine loop.",
                             self.current_wave)
                self.output_queue.put_nowait(
                    EngineCoreOutputs(wave_complete=self.current_wave))
            self.current_wave += 1

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Report whether any DP rank may still have unfinished requests.

        Only every 24th call performs the actual cross-rank check via
        ``ParallelConfig.has_unfinished_dp`` (and resets ``self.counter``);
        all other calls return True without communicating.
        """
        self.counter += 1
        if self.counter == 24:
            self.counter = 0
            return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                    local_unfinished)
        # Between sync points, optimistically assume work remains.
        return True