core.py 38.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
import queue
5
import signal
6
import sys
7
8
import threading
import time
9
from collections import deque
Rui Qiao's avatar
Rui Qiao committed
10
from collections.abc import Generator
11
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
12
from contextlib import ExitStack, contextmanager
13
from inspect import isclass, signature
14
from logging import DEBUG
15
from typing import Any, Callable, Optional, TypeVar, Union
16

17
import msgspec
18
19
import zmq

20
21
22
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
23
from vllm.logger import init_logger
24
from vllm.logging_utils.dump_input import dump_engine_exception
25
from vllm.lora.request import LoRARequest
26
27
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
28
from vllm.utils import make_zmq_socket, resolve_obj_by_qualname
29
30
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                         unify_kv_cache_configs)
31
from vllm.v1.core.sched.interface import SchedulerInterface
32
33
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
34
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
35
                            EngineCoreRequestType, UtilityOutput)
36
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
37
from vllm.v1.executor.abstract import Executor
38
from vllm.v1.kv_cache_interface import KVCacheConfig
39
from vllm.v1.metrics.stats import SchedulerStats
40
from vllm.v1.outputs import ModelRunnerOutput
41
from vllm.v1.request import Request, RequestStatus
42
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
43
from vllm.v1.structured_output import StructuredOutputManager
44
from vllm.v1.utils import EngineHandshakeMetadata, EngineZmqAddresses
45
46
47
48
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# Polling timeout; presumably seconds (per the `_S` suffix).
# NOTE(review): not referenced in the visible part of this file — confirm
# where it is consumed before changing it.
POLLING_TIMEOUT_S = 2.5
# How long the engine waits for the front-end's init message during the
# startup handshake (see EngineCoreProc.startup_handshake).
HANDSHAKE_TIMEOUT_MINS = 5

_R = TypeVar('_R')  # Return type for collective_rpc


class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the scheduler, the structured-output manager
    and the multi-modal input cache mirror, and advances them one step at a
    time. `step_with_batch_queue` is only usable when the batch queue is
    enabled (executor ``max_concurrent_batches > 1``).
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 executor_class: type[Executor],
                 log_stats: bool,
                 executor_fail_callback: Optional[Callable] = None):
        """Build the executor, KV caches, scheduler and auxiliary state.

        Args:
            vllm_config: Engine configuration. Its ``cache_config`` is
                updated in place with the profiled GPU/CPU block counts.
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Stored on the engine and forwarded to the scheduler.
            executor_fail_callback: Optional callable registered with the
                executor through ``register_failure_callback``.
        """
        assert vllm_config.model_config.runner_type != "pooling"

        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        self.vllm_config = vllm_config
        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(
                executor_fail_callback)

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
            self._initialize_kv_caches(vllm_config)

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler. The configured class may be given either as a
        # class object or as a fully-qualified name to be resolved.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls)

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size
            > 1,
            log_stats=self.log_stats,
        )

        # Setup MM Input Mapper.
        self.mm_input_cache_server = MirroredProcessingCache(
            vllm_config.model_config)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
                                                     SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

    def _initialize_kv_caches(
            self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
        """Profile memory, size the KV caches and initialize the executor.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, kv_cache_config)``, where
            ``num_cpu_blocks`` is always 0 and ``kv_cache_config`` is the
            first worker's (unified) config, used to build the scheduler.
        """
        # perf_counter is monotonic, so the reported duration cannot be
        # skewed by wall-clock adjustments during a long profiling run.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        # Profiles the peak memory usage of the model to determine how much
        # memory can be allocated for kv cache.
        available_gpu_memory = self.model_executor.determine_available_memory()

        # Both lists are indexed per worker and must line up.
        assert len(kv_cache_specs) == len(available_gpu_memory)
        # Get the kv cache tensor size
        kv_cache_configs = [
            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
                                available_gpu_memory_one_worker)
            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
            zip(kv_cache_specs, available_gpu_memory)
        ]

        # Since we use a shared centralized controller, we need the
        # `kv_cache_config` to be consistent across all workers to make sure
        # all the memory operators can be applied to all workers.
        unify_kv_cache_configs(kv_cache_configs)

        # All workers have the same kv_cache_config except layer names, so use
        # an arbitrary one to initialize the scheduler.
        assert all(cfg.num_blocks == kv_cache_configs[0].num_blocks
                   for cfg in kv_cache_configs)
        num_gpu_blocks = kv_cache_configs[0].num_blocks
        num_cpu_blocks = 0
        scheduler_kv_cache_config = kv_cache_configs[0]

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler.

        Resolves multi-modal inputs against the mirrored cache, starts
        grammar compilation for structured-output requests, and then hands
        the converted ``Request`` to the scheduler.
        """
        if request.mm_hashes is not None:
            # Here, if hash exists for a multimodal input, then it will be
            # fetched from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client cache, so
            # anything that has a hash must have a HIT cache entry here
            # as well.
            assert request.mm_inputs is not None
            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
                request.mm_inputs, request.mm_hashes)

        req = Request.from_engine_core_request(request)
        if req.use_structured_output:
            # Start grammar compilation asynchronously
            self.structured_output_manager.grammar_init(req)

        # NOTE(review): this only warns; `req.kv_transfer_params` is left
        # as-is, so "disabling" presumably relies on the scheduler ignoring
        # it when no connector exists — confirm.
        if req.kv_transfer_params is not None and (
                not self.scheduler.get_kv_connector()):
            logger.warning("Got kv_transfer_params, but no KVConnector found. "
                           "Disabling KVTransfer for this request.")

        self.scheduler.add_request(req)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def execute_model(self, scheduler_output: SchedulerOutput):
        """Run the model on one scheduled batch via the executor.

        On any failure (including ``BaseException`` subclasses) the engine
        state is dumped with ``dump_engine_exception`` and the original
        exception is re-raised.
        """
        try:
            return self.model_executor.execute_model(scheduler_output)
        except BaseException:
            # NOTE: dump_engine_exception is exception-free, so it cannot
            # replace the error being handled.
            dump_engine_exception(self.vllm_config, scheduler_output,
                                  self.scheduler.make_stats())
            # Bare `raise` re-raises the active exception unchanged.
            raise

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        model_output = self.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output)  # type: ignore

        return (engine_core_outputs,
                scheduler_output.total_num_scheduled_tokens > 0)

    def step_with_batch_queue(
            self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, filling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.

        The returned flag is True iff a new batch was scheduled this call.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None
        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        if not self.batch_queue.full():
            scheduler_output = self.scheduler.schedule()
            if scheduler_output.total_num_scheduled_tokens > 0:
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        scheduled_batch = (scheduler_output is not None
                           and scheduler_output.total_num_scheduled_tokens > 0)

        # If no more requests can be scheduled and the job queue is not empty,
        # block until the first batch in the job queue is finished.
        # TODO(comaniac): Ideally we should peek the first batch in the
        # job queue to check if it's finished before scheduling a new batch,
        # but peeking the first element in a queue is not thread-safe,
        # so we need more work.
        if not scheduled_batch and not self.batch_queue.empty():
            future, scheduler_output = self.batch_queue.get_nowait()
            # Blocking until the first result is available.
            model_output = future.result()
            self.batch_queue.task_done()
            engine_core_outputs = (self.scheduler.update_from_output(
                scheduler_output, model_output))

        return engine_core_outputs, scheduled_batch

    def shutdown(self):
        """Release the structured-output backend, executor and scheduler."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True):
        """Forward a profiler start/stop request to the executor."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Reset the engine-side multi-modal input cache mirror."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
        if self.scheduler.has_unfinished_requests():
            logger.warning("Resetting the multi-modal cache when requests are "
                           "in progress may lead to desynced internal caches.")

        self.mm_input_cache_server.reset()

    def reset_prefix_cache(self):
        """Ask the scheduler to reset its prefix cache."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Put the executor to sleep at the given level."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """Wake the executor, optionally only for the given tags."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` attribute."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run ``execute_dummy_batch`` on all workers via collective RPC."""
        self.model_executor.collective_rpc("execute_dummy_batch")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter through the executor."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter by id through the executor."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the ids of the LoRA adapters known to the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter by id through the executor."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Delegate saving sharded model state to the executor."""
        self.model_executor.save_sharded_state(path=path,
                                               pattern=pattern,
                                               max_size=max_size)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Invoke ``method`` on the workers through the executor's RPC."""
        return self.model_executor.collective_rpc(method, timeout, args,
                                                  kwargs)

    def save_tensorized_model(
        self,
        tensorizer_config,
    ) -> None:
        """Delegate saving a tensorized model to the executor."""
        self.model_executor.save_tensorized_model(
            tensorizer_config=tensorizer_config, )


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

    # Sentinel pushed onto the output queue to tell clients the engine died.
    ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD'

    def __init__(
        self,
        vllm_config: VllmConfig,
        on_head_node: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        engine_index: int = 0,
    ):
        """Handshake with the front-end, build the engine and start IO threads.

        Args:
            vllm_config: Engine configuration; may be updated by the
                handshake before the engine is built.
            on_head_node: Reported to the front-end as ``"local"``.
            handshake_address: ZMQ address of the front-end handshake socket.
            executor_class: Executor type passed to ``EngineCore``.
            log_stats: Passed to ``EngineCore``.
            engine_index: Used as this engine's ZMQ identity (2 bytes,
                little-endian) and stamped on every output.
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
                                              bytes]]()

        def executor_fail_callback():
            # Wake the busy loop with a sentinel request, which
            # _handle_client_request turns into a RuntimeError.
            self.input_queue.put_nowait(
                (EngineCoreRequestType.EXECUTOR_FAILED, b''))

        self.engine_index = engine_index
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        with self._perform_handshake(handshake_address, identity, on_head_node,
                                     vllm_config) as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self._init_data_parallel(vllm_config)

            super().__init__(vllm_config, executor_class, log_stats,
                             executor_fail_callback)

        self.step_fn = (self.step if self.batch_queue is None else
                        self.step_with_batch_queue)

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        threading.Thread(target=self.process_input_sockets,
                         args=(addresses.inputs, addresses.coordinator_input,
                               identity),
                         daemon=True).start()
        self.output_thread = threading.Thread(
            target=self.process_output_sockets,
            args=(addresses.outputs, addresses.coordinator_output,
                  self.engine_index),
            daemon=True)
        self.output_thread.start()

    @contextmanager
    def _perform_handshake(
            self, handshake_address: str, identity: bytes, on_head_node: bool,
            vllm_config: VllmConfig
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run the startup handshake with the front-end around engine init.

        Sends HELLO and applies the received parallel config (see
        `startup_handshake`), re-runs ``vllm_config.__post_init__()``, then
        yields the ZMQ addresses. Only if the ``with`` body completes without
        raising is a READY message (including ``num_gpu_blocks``) sent.

        The ZMQ context is terminated on exit. Because the socket uses
        ``linger=5000`` ms, exit may block briefly while READY is flushed.
        """
        # Use the context as a context manager so it is terminated (after the
        # socket closes) rather than leaked for the lifetime of the process.
        with zmq.Context() as input_ctx:
            with make_zmq_socket(input_ctx,
                                 handshake_address,
                                 zmq.DEALER,
                                 identity=identity,
                                 linger=5000,
                                 bind=False) as handshake_socket:
                # Register engine with front-end.
                addresses = self.startup_handshake(handshake_socket,
                                                   on_head_node,
                                                   vllm_config.parallel_config)

                # Update config which may have changed from the handshake
                vllm_config.__post_init__()

                yield addresses

                # Send ready message.
                num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
                handshake_socket.send(
                    msgspec.msgpack.encode({
                        "status": "READY",
                        "local": on_head_node,
                        "num_gpu_blocks": num_gpu_blocks,
                    }))

    @staticmethod
    def startup_handshake(
            handshake_socket: zmq.Socket, on_head_node: bool,
            parallel_config: ParallelConfig) -> EngineZmqAddresses:
        """Register with the front-end and receive startup metadata.

        Sends a HELLO message, waits up to ``HANDSHAKE_TIMEOUT_MINS`` for the
        reply, copies every received parallel-config field onto
        ``parallel_config`` in place and returns the received addresses.

        Raises:
            RuntimeError: if no reply arrives within the timeout.
        """
        # Send registration message.
        handshake_socket.send(
            msgspec.msgpack.encode({
                "status": "HELLO",
                "local": on_head_node,
            }))

        # Receive initialization message.
        logger.info("Waiting for init message from front-end.")
        # zmq poll timeouts are given in milliseconds.
        if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
            raise RuntimeError("Did not receive response from front-end "
                               f"process within {HANDSHAKE_TIMEOUT_MINS} "
                               f"minutes")
        init_bytes = handshake_socket.recv()
        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            init_bytes, type=EngineHandshakeMetadata)
        logger.debug("Received init message: %s", init_message)

        received_parallel_config = init_message.parallel_config
        for key, value in received_parallel_config.items():
            setattr(parallel_config, key, value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args,
                        dp_rank: int = 0,
                        local_dp_rank: int = 0,
                        **kwargs):
        """Launch EngineCore busy loop in background process.

        Installs SIGTERM/SIGINT handlers that raise ``SystemExit`` once,
        builds a ``DPEngineCoreProc`` when data parallelism is in use (or
        ``dp_rank > 0``) and an ``EngineCoreProc`` otherwise, then runs the
        busy loop. On any other exception after construction the client is
        notified via ``_send_engine_dead``. The engine is always shut down
        if it was constructed.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: Optional[EngineCoreProc] = None
        try:
            parallel_config: ParallelConfig = kwargs[
                "vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1 or dp_rank > 0:
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                engine_core = EngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            # Bare `raise` preserves the original traceback.
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Hook for data-parallel setup; a no-op for a single engine."""
        pass

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()
            # 2) Step the engine core and return the outputs.
            self._process_engine_step()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed.

        Blocks on the input queue while ``engines_running`` is False and the
        scheduler has no requests, then drains any remaining queued client
        requests without blocking.
        """

        waited = False
        while not self.engines_running and not self.scheduler.has_requests():
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                waited = True
            req = self.input_queue.get()
            self._handle_client_request(*req)

        if waited:
            logger.debug("EngineCore loop active.")

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

    def _process_engine_step(self) -> bool:
        """Run one engine step and enqueue its outputs.

        Called after `_process_input_queue` returns, i.e. when the scheduler
        has requests or ``engines_running`` is set. Returns whether the model
        was executed in this step.
        """

        # Step the engine core.
        outputs, model_executed = self.step_fn()
        # Put EngineCoreOutputs into the output queue.
        for output in (outputs.items() if outputs else ()):
            self.output_queue.put_nowait(output)

        return model_executed

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Dispatch request from client.

        Utility-call failures are reported back to the client as a
        ``failure_message``. Shutdown signals (``SystemExit`` raised by the
        signal handler, ``KeyboardInterrupt``) are re-raised.

        Raises:
            RuntimeError: on an ``EXECUTOR_FAILED`` notification.
        """

        if request_type == EngineCoreRequestType.ADD:
            self.add_request(request)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                output.result = method(
                    *self._convert_msgspec_args(method, args))
            except (SystemExit, KeyboardInterrupt):
                # The SIGTERM/SIGINT handler raises SystemExit only once;
                # swallowing it here would leave the engine running and
                # unable to react to further signals.
                raise
            except BaseException as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (f"Call to {method_name} method"
                                          f" failed: {str(e)}")
            self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=output)))
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error("Unrecognized input request type encountered: %s",
                         request_type)

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
         arg type, try converting to msgspec object."""
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
            msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation) else v
            for v, p in zip(args, arg_types))

    def _send_engine_dead(self):
        """Send EngineDead status to the EngineCoreClient."""

        # Put ENGINE_CORE_DEAD in the queue.
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Wait until msg sent by the daemon before shutdown.
        self.output_thread.join(timeout=5.0)
        if self.output_thread.is_alive():
            logger.fatal("vLLM shutdown signal from EngineCore failed "
                         "to send. Please report this issue.")

    def process_input_sockets(self, input_addresses: list[str],
                              coord_input_address: Optional[str],
                              identity: bytes):
        """Input socket IO thread.

        Connects a DEALER socket to each input address (and an XSUB socket
        to the coordinator, if given), decodes incoming ``(type, data...)``
        frames and pushes ``(request_type, request)`` onto ``input_queue``.
        Loops forever; it is started as a daemon thread.
        """

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx,
                                    input_address,
                                    zmq.DEALER,
                                    identity=identity,
                                    bind=False))
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(ctx,
                                    coord_input_address,
                                    zmq.XSUB,
                                    identity=identity,
                                    bind=False))
                # Send subscription message to coordinator (a 0x01 byte with
                # an empty topic, per the XSUB protocol).
                coord_socket.send(b'\x01')

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b'')
                poller.register(input_socket, zmq.POLLIN)
            if coord_socket is not None:
                poller.register(coord_socket, zmq.POLLIN)

            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(
                        copy=False)
                    request_type = EngineCoreRequestType(
                        bytes(type_frame.buffer))

                    # Deserialize the request data.
                    decoder = add_request_decoder if (
                        request_type
                        == EngineCoreRequestType.ADD) else generic_decoder
                    request = decoder.decode(data_frames)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

    def process_output_sockets(self, output_paths: list[str],
                               coord_output_path: Optional[str],
                               engine_index: int):
        """Output socket IO thread.

        Takes ``(client_index, outputs)`` items off ``output_queue``, stamps
        ``engine_index`` on them and sends them to the matching client socket
        (``client_index == -1`` routes to the coordinator). Stops after
        forwarding ``ENGINE_CORE_DEAD`` to every client socket.
        """

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000))
                for output_path in output_paths
            ]
            coord_socket = stack.enter_context(
                make_zmq_socket(
                    ctx, coord_output_path, zmq.PUSH, bind=False,
                    linger=4000)) if coord_output_path is not None else None
            max_reuse_bufs = len(sockets) + 1

            while True:
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    for socket in sockets:
                        socket.send(output)
                    break
                assert not isinstance(output, bytes)
                client_index, outputs = output
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Reclaim buffers that zmq is finished with.
                while pending and pending[-1][0].done:
                    reuse_buffers.append(pending.pop()[2])

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = sockets[client_index].send_multipart(buffers,
                                                               copy=False,
                                                               track=True)
                if not tracker.done:
                    # Only multi-frame messages reference memory owned by
                    # `outputs`, so only then must it be kept alive.
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
                elif len(reuse_buffers) < max_reuse_bufs:
                    # Limit the number of buffers to reuse.
                    reuse_buffers.append(buffer)


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    # Number of busy-loop iterations between all-reduces that check whether
    # any DP rank still has unfinished requests. Between syncs the engine
    # optimistically assumes work remains (see _has_global_unfinished_reqs).
    _FINISH_SYNC_INTERVAL: int = 24

    def __init__(
        self,
        vllm_config: VllmConfig,
        on_head_node: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        # Prefix stdout/stderr with process name/pid before the engine
        # (and anything it prints) is initialized.
        self._decorate_logs()

        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.counter = 0
        # Index of the current DP "wave" of requests.
        self.current_wave = 0
        # Last value returned by scheduler.get_request_counts() that was
        # published; used to avoid re-publishing unchanged counts.
        self.last_counts = (0, 0)

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(vllm_config, on_head_node, handshake_address,
                         executor_class, log_stats, dp_rank)

    def _decorate_logs(self):
        """Add a process-name/pid prefix to this process's stdout and
        stderr before the engine is initialized."""
        from multiprocessing import current_process
        process_name = current_process().name
        pid = os.getpid()
        _add_prefix(sys.stdout, process_name, pid)
        _add_prefix(sys.stderr, process_name, pid)

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Configure devices and the stateless DP process group.

        Restricts the visible devices to this local DP rank's slice,
        makes the KV-transfer engine id unique per local DP rank, and
        creates ``self.dp_group``.
        """
        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        if vllm_config.kv_transfer_config is not None:
            # modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug("Setting kv_transfer_config.engine_id to %s",
                         vllm_config.kv_transfer_config.engine_id)

        from vllm.platforms import current_platform
        device_control_env_var = current_platform.device_control_env_var
        world_size = vllm_config.parallel_config.world_size
        # Each local DP rank owns a contiguous block of `world_size`
        # devices; expose only that block to this process.
        os.environ[device_control_env_var] = ",".join(
            str(current_platform.device_id_to_physical_device_id(i))
            for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
                           world_size))

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        """Shut down the engine, then destroy the DP process group if one
        was created (it may not exist if init failed early)."""
        super().shutdown()
        if (dp_group := getattr(self, "dp_group", None)) is not None:
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: EngineCoreRequest):
        """Add a request, reconciling its wave number with ours first.

        A request from a newer wave advances ``current_wave``. A request
        for an already-completed wave, arriving while engines are idle,
        triggers a ``start_wave`` message to the coordinator.
        """
        if self.has_coordinator and request.current_wave != self.current_wave:
            if request.current_wave > self.current_wave:
                self.current_wave = request.current_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave)))

        super().add_request(request)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Handle START_DP_WAVE locally; delegate everything else."""
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                    new_wave >= self.current_wave):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.",
                                 new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Send scheduler request counts to the coordinator when changed.

        No-op when there is no coordinator. Messages are queued with
        client index -1, which the output loop routes to the coordinator.
        """
        if not self.has_coordinator:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(*counts)
            self.output_queue.put_nowait(
                (-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs)

            if not self.engines_running:
                if self.dp_rank == 0:
                    # Notify client that we are pausing the loop.
                    logger.debug("Wave %d finished, pausing engine loop.",
                                 self.current_wave)
                    self.output_queue.put_nowait(
                        (-1,
                         EngineCoreOutputs(wave_complete=self.current_wave)))
                # Every rank advances its wave so they stay in agreement.
                self.current_wave += 1

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank still has unfinished requests.

        Optimization: the cross-rank all-reduce only runs once every
        ``_FINISH_SYNC_INTERVAL`` calls; on the other calls this returns
        True so the engines keep stepping until the next sync.
        """
        self.counter += 1
        # `<` rather than `!=` so a lowered interval can never cause the
        # counter to overshoot and skip syncing forever.
        if self.counter < self._FINISH_SYNC_INTERVAL:
            return True
        self.counter = 0

        return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                local_unfinished)


class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        on_head_node: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        """Record the pre-known addresses and DP ranks, then init the engine.

        Args:
            vllm_config: Engine config; its DP rank fields are overwritten
                with ``dp_rank`` / ``local_dp_rank``.
            on_head_node: Forwarded to the parent constructor.
            addresses: ZMQ addresses yielded by ``_perform_handshake``
                instead of performing a real handshake.
            executor_class: Forwarded to the parent constructor.
            log_stats: Forwarded to the parent constructor.
            dp_rank: Global data-parallel rank of this actor.
            local_dp_rank: Data-parallel rank local to this node.
        """
        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_rank = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = \
            local_dp_rank

        # Ray sets CUDA_VISIBLE_DEVICES to empty string,
        # we clean this up to be able to properly initialize
        # data parallel groups. Use pop() so a missing variable
        # (e.g. non-CUDA setups) does not raise KeyError.
        os.environ.pop('CUDA_VISIBLE_DEVICES', None)

        # No handshake address is needed: _perform_handshake is overridden.
        super().__init__(vllm_config, on_head_node, "", executor_class,
                         log_stats)

    def _decorate_logs(self):
        """Skip the stdout/stderr prefixing done for non-Ray processes."""
        pass

    @contextmanager
    def _perform_handshake(self, handshake_address: str, identity: bytes,
                           on_head_node: bool, vllm_config: VllmConfig):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.

        Exceptions are logged and re-raised; shutdown() always runs.
        """
        try:
            self.run_busy_loop()
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()