core.py 43.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
import queue
5
import signal
6
import sys
7
8
import threading
import time
9
from collections import deque
Rui Qiao's avatar
Rui Qiao committed
10
from collections.abc import Generator
11
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
12
from contextlib import ExitStack, contextmanager
13
from inspect import isclass, signature
14
from logging import DEBUG
15
from typing import Any, Callable, Optional, TypeVar, Union
16

17
import msgspec
lizhigong's avatar
lizhigong committed
18
19
from vllm import envs
from vllm.zero_overhead.v1.core import engine_core_step
20
21
import zmq

22
23
24
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
25
from vllm.logger import init_logger
26
from vllm.logging_utils.dump_input import dump_engine_exception
27
from vllm.lora.request import LoRARequest
28
29
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
30
from vllm.utils import make_zmq_socket, resolve_obj_by_qualname
31
32
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                         unify_kv_cache_configs)
33
from vllm.v1.core.sched.interface import SchedulerInterface
34
35
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
36
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
37
                            EngineCoreRequestType, UtilityOutput)
38
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
39
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
40
from vllm.v1.executor.abstract import Executor
41
from vllm.v1.kv_cache_interface import KVCacheConfig
42
from vllm.v1.metrics.stats import SchedulerStats
43
from vllm.v1.outputs import ModelRunnerOutput
44
from vllm.v1.request import Request, RequestStatus
45
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
46
from vllm.v1.structured_output import StructuredOutputManager
47
48
49
50
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# Socket/queue polling timeout in seconds (per the ``_S`` suffix).
# NOTE(review): presumably consumed by the IO threads further down; confirm.
POLLING_TIMEOUT_S = 2.5
# Upper bound on how long startup_handshake waits for the front-end's init
# message; multiplied by 60_000 to get the millisecond zmq poll timeout.
HANDSHAKE_TIMEOUT_MINS = 5

# Return type for collective_rpc
_R = TypeVar("_R")


class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the scheduler, the structured-output manager and
    the multi-modal input cache, and advances them one step at a time via
    ``step`` (or ``step_with_batch_queue`` when a batch queue is configured).
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 executor_class: type[Executor],
                 log_stats: bool,
                 executor_fail_callback: Optional[Callable] = None):
        """Create the executor, KV caches, scheduler and optional batch queue.

        Args:
            vllm_config: Engine configuration. Its
                ``cache_config.num_gpu_blocks`` / ``num_cpu_blocks`` fields are
                overwritten with the values chosen during KV cache setup.
            executor_class: Executor type; instantiated with ``vllm_config``.
            log_stats: Forwarded to the scheduler.
            executor_fail_callback: Optional callable registered on the
                executor through ``register_failure_callback``.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        self.vllm_config = vllm_config
        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(
                executor_fail_callback)

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
            self._initialize_kv_caches(vllm_config)

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache",
                            args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler. The scheduler class may be configured either as a
        # class object or as a fully-qualified name string.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls)

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=(
                vllm_config.parallel_config.data_parallel_size > 1),
            log_stats=self.log_stats,
        )

        # Setup MM Input Mapper.
        self.mm_input_cache_server = MirroredProcessingCache(
            vllm_config.model_config)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
                                                     SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

    def _initialize_kv_caches(
            self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
        """Profile memory, size the KV caches and initialize the workers.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, kv_cache_config)``, where
            ``kv_cache_config`` is the (first) per-worker config handed to the
            scheduler and ``num_cpu_blocks`` is always 0 in this method.
        """
        # perf_counter is monotonic, so the duration logged below cannot be
        # distorted by wall-clock adjustments the way time.time() can.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        # Profiles the peak memory usage of the model to determine how much
        # memory can be allocated for kv cache.
        available_gpu_memory = self.model_executor.determine_available_memory()

        # One spec and one memory figure per worker.
        assert len(kv_cache_specs) == len(available_gpu_memory)
        # Get the kv cache tensor size
        kv_cache_configs = [
            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
                                available_gpu_memory_one_worker)
            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
            zip(kv_cache_specs, available_gpu_memory)
        ]

        # Since we use a shared centralized controller, we need the
        # `kv_cache_config` to be consistent across all workers to make sure
        # all the memory operators can be applied to all workers.
        unify_kv_cache_configs(kv_cache_configs)

        # All workers have the same kv_cache_config except layer names, so use
        # an arbitrary one to initialize the scheduler.
        assert all(cfg.num_blocks == kv_cache_configs[0].num_blocks
                   for cfg in kv_cache_configs)
        num_gpu_blocks = kv_cache_configs[0].num_blocks
        num_cpu_blocks = 0
        scheduler_kv_cache_config = kv_cache_configs[0]

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler.

        Resolves mirrored multi-modal inputs, kicks off grammar compilation
        for structured-output requests, then hands the request to the
        scheduler.
        """
        if request.mm_hashes is not None:
            # Here, if hash exists for a multimodal input, then it will be
            # fetched from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client cache, so
            # anything that has a hash must have a HIT cache entry here
            # as well.
            assert request.mm_inputs is not None
            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
                request.mm_inputs, request.mm_hashes)

        req = Request.from_engine_core_request(request)
        if req.use_structured_output:
            # Start grammar compilation asynchronously
            self.structured_output_manager.grammar_init(req)

        # NOTE(review): the message says KV transfer is disabled, but
        # kv_transfer_params is left on the request here; confirm the
        # scheduler ignores it when no connector is configured.
        if req.kv_transfer_params is not None and (
                not self.scheduler.get_kv_connector()):
            logger.warning("Got kv_transfer_params, but no KVConnector found. "
                           "Disabling KVTransfer for this request.")

        self.scheduler.add_request(req)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def execute_model(self, scheduler_output: SchedulerOutput):
        """Run the model on one scheduled batch.

        On failure, dumps engine/scheduler state for debugging and re-raises
        the original exception unchanged.
        """
        try:
            return self.model_executor.execute_model(scheduler_output)
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: This method is exception-free
            dump_engine_exception(self.vllm_config, scheduler_output,
                                  self.scheduler.make_stats())
            raise

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed. When ``envs.VLLM_ZERO_OVERHEAD`` is set, the whole step
        is delegated to ``engine_core_step``.
        """
        if envs.VLLM_ZERO_OVERHEAD:
            return engine_core_step(self)

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        model_output = self.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output)  # type: ignore

        return (engine_core_outputs,
                scheduler_output.total_num_scheduled_tokens > 0)

    def step_with_batch_queue(
            self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, fulfilling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.

        Returns:
            ``(engine_core_outputs or None, whether a new batch was queued)``.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        if not self.batch_queue.full():
            # Only allow "minimal progress injection" when no batch is in
            # flight (queue empty), to avoid issuing unnecessary 0-token
            # batches while a batch is already running. The hook is optional:
            # schedulers that don't define it are skipped, but errors raised
            # by a scheduler that does define it are no longer swallowed.
            set_allow_minimal_injection = getattr(
                self.scheduler, "set_allow_minimal_injection", None)
            if set_allow_minimal_injection is not None:
                set_allow_minimal_injection(self.batch_queue.empty())

            scheduler_output = self.scheduler.schedule()
            if scheduler_output.total_num_scheduled_tokens > 0:
                if envs.VLLM_PP_DEBUG:
                    # `sys` and `os` are imported at module level.
                    num_run_reqs = (
                        len(scheduler_output.scheduled_new_reqs) +
                        scheduler_output.scheduled_cached_reqs.num_reqs)
                    sys.stderr.write(
                        f"[pid- {os.getpid()}]running requests in micro batch "
                        f"is:{num_run_reqs}, total_num_scheduled_tokens is "
                        f"{scheduler_output.total_num_scheduled_tokens}\n")
                    sys.stderr.flush()
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        scheduled_batch = (scheduler_output is not None
                           and scheduler_output.total_num_scheduled_tokens > 0)

        # If no more requests can be scheduled and the job queue is not empty,
        # block until the first batch in the job queue is finished.
        # TODO(comaniac): Ideally we should peek the first batch in the
        # job queue to check if it's finished before scheduling a new batch,
        # but peeking the first element in a queue is not thread-safe,
        # so we need more work.
        if not scheduled_batch and not self.batch_queue.empty():
            future, scheduler_output = self.batch_queue.get_nowait()
            # Blocking until the first result is available.
            model_output = future.result()
            self.batch_queue.task_done()
            engine_core_outputs = (self.scheduler.update_from_output(
                scheduler_output, model_output))

        return engine_core_outputs, scheduled_batch

    def shutdown(self):
        """Release the structured-output backend, executor and scheduler."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True):
        """Forward a profiler start (default) or stop request to the executor."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Clear the server-side (P1) multi-modal input cache."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
        if self.scheduler.has_unfinished_requests():
            logger.warning("Resetting the multi-modal cache when requests are "
                           "in progress may lead to desynced internal caches.")

        self.mm_input_cache_server.reset()

    def reset_prefix_cache(self):
        """Delegate prefix-cache reset to the scheduler."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Forward a sleep request with the given level to the executor."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """Forward a wake-up request (optionally scoped by tags)."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` attribute."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Ask every worker to run ``execute_dummy_batch``."""
        self.model_executor.collective_rpc("execute_dummy_batch")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter via the executor."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter via the executor."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter via the executor."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Forward a sharded-state save request to the executor."""
        self.model_executor.save_sharded_state(path=path,
                                               pattern=pattern,
                                               max_size=max_size)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Invoke ``method`` on all workers through the executor."""
        return self.model_executor.collective_rpc(method, timeout, args,
                                                  kwargs)

    def save_tensorized_model(
        self,
        tensorizer_config,
    ) -> None:
        """Forward a tensorizer save request to the executor."""
        self.model_executor.save_tensorized_model(
            tensorizer_config=tensorizer_config, )


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

382
383
    # Sentinel that _send_engine_dead puts on output_queue to signal the
    # EngineCoreClient that this engine core has died.
    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: Optional[str] = None,
        engine_index: int = 0,
    ):
        """Handshake with the front-end, build the engine, start IO threads.

        Creates the input/output queues, performs the startup handshake(s)
        (during which the base ``EngineCore`` is constructed), selects the
        step function and finally starts the daemon threads that move data
        between the ZMQ sockets and the queues.
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
                                              bytes]]()

        def executor_fail_callback():
            # Wake the busy loop with a sentinel request;
            # _handle_client_request turns it into a RuntimeError.
            self.input_queue.put_nowait(
                (EngineCoreRequestType.EXECUTOR_FAILED, b''))

        self.engine_index = engine_index
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        with self._perform_handshakes(handshake_address, identity,
                                      local_client, vllm_config,
                                      client_handshake_address) as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address)
            # Only publish request queue stats to coordinator for "internal"
            # LB mode.
            self.publish_dp_lb_stats = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb)

            self._init_data_parallel(vllm_config)

            # The base engine must be built while the handshake is still open:
            # the READY message sent on exit reports its num_gpu_blocks.
            super().__init__(vllm_config, executor_class, log_stats,
                             executor_fail_callback)

        if self.batch_queue is None:
            self.step_fn = self.step
        else:
            self.step_fn = self.step_with_batch_queue

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        input_thread = threading.Thread(
            target=self.process_input_sockets,
            args=(addresses.inputs, addresses.coordinator_input, identity),
            daemon=True)
        input_thread.start()

        self.output_thread = threading.Thread(
            target=self.process_output_sockets,
            args=(addresses.outputs, addresses.coordinator_output,
                  self.engine_index),
            daemon=True)
        self.output_thread.start()

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: Optional[str],
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Perform startup handshakes.

        For DP=1 or offline mode, this is with the colocated front-end process.

        For DP>1 with internal loadbalancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external loadbalancing, two handshakes are performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
        with the exception of the rank 0 engine itself which doesn't require
        the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
        server is not scaled out), OR the launcher process running the
        run_multi_api_server() function in serve.py.

        Yields the addresses obtained from the handshake(s); when the caller's
        ``with`` body finishes, READY messages are sent, the handshake ZMQ
        context is terminated and ``vllm_config.__post_init__()`` is re-run.
        """
        # Terminate the handshake context deterministically instead of leaving
        # it to garbage collection. _perform_handshake closes its sockets
        # before this block exits, so term() only waits for their linger
        # period while the final READY message is flushed.
        with zmq.Context() as input_ctx:
            is_local = local_client and client_handshake_address is None
            handshake = self._perform_handshake(input_ctx, handshake_address,
                                                identity, is_local,
                                                vllm_config,
                                                vllm_config.parallel_config)
            if client_handshake_address is None:
                with handshake as addresses:
                    yield addresses
            else:
                # External-LB case: the colocated front-end supplies the
                # client-facing input/output socket addresses.
                local_handshake = self._perform_handshake(
                    input_ctx, client_handshake_address, identity,
                    local_client, vllm_config)
                with handshake as addresses, \
                        local_handshake as client_addresses:
                    addresses.inputs = client_addresses.inputs
                    addresses.outputs = client_addresses.outputs
                    yield addresses

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: Optional[ParallelConfig] = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run a single HELLO / init / READY exchange with one front-end.

        Connects a DEALER socket (linger 5000 ms) to ``handshake_address``,
        registers via ``startup_handshake`` and yields the returned addresses.
        Once the caller's ``with`` body completes normally, a READY message
        carrying ``num_gpu_blocks`` and ``dp_stats_address`` is sent.
        """
        socket_cm = make_zmq_socket(ctx,
                                    handshake_address,
                                    zmq.DEALER,
                                    identity=identity,
                                    linger=5000,
                                    bind=False)
        with socket_cm as handshake_socket:
            # Register engine with front-end.
            addresses = self.startup_handshake(handshake_socket, local_client,
                                               parallel_config_to_update)
            yield addresses

            # Send ready message. We pass back the coordinator stats update
            # address here for the external LB case for our colocated
            # front-end to use (coordinator only runs with rank 0).
            ready_payload = {
                "status": "READY",
                "local": local_client,
                "num_gpu_blocks": vllm_config.cache_config.num_gpu_blocks,
                "dp_stats_address": self.frontend_stats_publish_address,
            }
            handshake_socket.send(msgspec.msgpack.encode(ready_payload))

    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        parallel_config: Optional[ParallelConfig] = None,
    ) -> EngineZmqAddresses:
        """Send HELLO, wait for the front-end's init message, return addresses.

        If ``parallel_config`` is given, every entry of the init message's
        ``parallel_config`` mapping is copied onto it with ``setattr``.

        Raises:
            RuntimeError: if no reply arrives within HANDSHAKE_TIMEOUT_MINS.
        """
        # Send registration message.
        hello_payload = {"status": "HELLO", "local": local_client}
        handshake_socket.send(msgspec.msgpack.encode(hello_payload))

        # Receive initialization message.
        logger.info("Waiting for init message from front-end.")
        poll_timeout_ms = HANDSHAKE_TIMEOUT_MINS * 60_000
        if not handshake_socket.poll(timeout=poll_timeout_ms):
            raise RuntimeError("Did not receive response from front-end "
                               f"process within {HANDSHAKE_TIMEOUT_MINS} "
                               f"minutes")

        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            handshake_socket.recv(), type=EngineHandshakeMetadata)
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for field_name, field_value in (
                    init_message.parallel_config.items()):
                setattr(parallel_config, field_name, field_value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args,
                        dp_rank: int = 0,
                        local_dp_rank: int = 0,
                        **kwargs):
        """Launch EngineCore busy loop in background process.

        Installs SIGTERM/SIGINT handlers, builds either a ``DPEngineCoreProc``
        (data parallel) or an ``EngineCoreProc`` from ``args``/``kwargs`` and
        runs its busy loop. On a fatal error after construction the client is
        notified via ``_send_engine_dead``; the engine is always shut down.
        """
        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if shutdown_requested:
                return
            shutdown_requested = True
            raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        for handled_signal in (signal.SIGTERM, signal.SIGINT):
            signal.signal(handled_signal, signal_handler)

        engine_core: Optional[EngineCoreProc] = None
        try:
            parallel_config: ParallelConfig = kwargs[
                "vllm_config"].parallel_config
            is_data_parallel = (parallel_config.data_parallel_size > 1
                                or dp_rank > 0)
            if not is_data_parallel:
                engine_core = EngineCoreProc(*args, **kwargs)
            else:
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is not None:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            else:
                logger.exception("EngineCore failed to start.")
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Data-parallel setup hook; intentionally a no-op here.

        Called from ``__init__`` before ``EngineCore.__init__`` runs.
        NOTE(review): presumably overridden by DPEngineCoreProc; confirm.
        """

    def run_busy_loop(self):
        """Core busy loop of the EngineCore.

        Never returns normally; it runs until an exception escapes, e.g. the
        SystemExit raised by run_engine_core's SIGINT/SIGTERM handler.
        """
        while True:
            # Block on client input until there is work, then advance the
            # engine by one step and publish whatever it produced.
            self._process_input_queue()
            self._process_engine_step()

    def _process_input_queue(self):
        """Handle client requests; return once an engine step should run.

        While the engines are not running and the scheduler has no requests,
        blocks on ``input_queue``. Afterwards, drains any requests that are
        already queued without blocking.
        """
        logged_wait = False
        while not (self.engines_running or self.scheduler.has_requests()):
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                logged_wait = True
            self._handle_client_request(*self.input_queue.get())

        if logged_wait:
            logger.debug("EngineCore loop active.")

        # Handle any more client requests.
        while not self.input_queue.empty():
            self._handle_client_request(*self.input_queue.get_nowait())

    def _process_engine_step(self) -> bool:
        """Run one engine step and enqueue its per-client outputs.

        Called only when there are unfinished local requests.

        Returns:
            The "model executed" flag reported by ``step_fn``.
        """
        step_outputs, model_executed = self.step_fn()
        # step_fn may return None or an empty dict when nothing is ready.
        if step_outputs:
            for client_output in step_outputs.items():
                self.output_queue.put_nowait(client_output)
        return model_executed

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Dispatch request from client.

        ADD and ABORT go to the scheduler. UTILITY calls the named method on
        this engine and queues a ``UtilityOutput`` (result or failure message)
        for the calling client.

        Raises:
            RuntimeError: when the EXECUTOR_FAILED sentinel is received.
        """
        if request_type == EngineCoreRequestType.ADD:
            self.add_request(request)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                output.result = method(
                    *self._convert_msgspec_args(method, args))
            except Exception as e:
                # Catch Exception, not BaseException: run_engine_core's
                # signal handler raises SystemExit only once, so swallowing
                # it here would turn a SIGTERM/SIGINT received during a
                # utility call into a reported call failure and leave the
                # engine running with further signals ignored.
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (f"Call to {method_name} method"
                                          f" failed: {str(e)}")
            self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=output)))
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error("Unrecognized input request type encountered: %s",
                         request_type)

    @staticmethod
    def _convert_msgspec_args(method, args):
        """Coerce utility-call arguments to the target method's msgspec types.

        Utility-call payloads are decoded without type information, so an
        argument meant to be a ``msgspec.Struct`` may arrive as a generic
        container. For each positional argument whose parameter annotation
        is a ``msgspec.Struct`` subclass, and which is not already an
        instance of it, convert it with ``msgspec.convert``. All other
        arguments are passed through unchanged.

        Args:
            method: Callable (typically a bound method) that will receive
                ``args``.
            args: Positional arguments for ``method``; may be empty/None.

        Returns:
            ``args`` itself when it is falsy, otherwise a tuple of the
            (possibly converted) arguments.

        Raises:
            TypeError: If more positional arguments are supplied than
                ``method`` declares parameters.
        """
        if not args:
            return args
        params = list(signature(method).parameters.values())
        # Explicit check rather than ``assert`` so that it is not stripped
        # when Python runs with -O.
        if len(args) > len(params):
            raise TypeError(
                f"{getattr(method, '__name__', method)} takes at most "
                f"{len(params)} positional arguments ({len(args)} given)")
        converted = []
        for value, param in zip(args, params):
            target = param.annotation
            if (isclass(target) and issubclass(target, msgspec.Struct)
                    and not isinstance(value, target)):
                value = msgspec.convert(value, type=target)
            converted.append(value)
        return tuple(converted)

    def _send_engine_dead(self):
        """Send the ENGINE_CORE_DEAD sentinel to the EngineCoreClient.

        Queues the sentinel for the output thread (presumably the one running
        ``process_output_sockets``, which forwards the sentinel to every
        client socket and then leaves its loop). Waits up to 5 seconds for
        that thread to finish and logs a fatal message if it is still alive
        afterwards.
        """
        # Put ENGINE_CORE_DEAD in the queue.
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Wait until the message has been sent before shutting down.
        self.output_thread.join(timeout=5.0)
        if self.output_thread.is_alive():
            logger.fatal("vLLM shutdown signal from EngineCore failed "
                         "to send. Please report this issue.")

    def process_input_sockets(self, input_addresses: list[str],
                              coord_input_address: Optional[str],
                              identity: bytes):
        """Input socket IO thread.

        Connects a DEALER socket to each front-end input address and, when a
        coordinator address is given, an XSUB socket to the DP coordinator.
        It then loops forever: it receives multipart messages, decodes them,
        and pushes ``(request_type, request)`` tuples onto
        ``self.input_queue`` for the core busy loop.

        Args:
            input_addresses: Front-end addresses to connect DEALER sockets to.
            coord_input_address: Coordinator address, or None when there is
                no coordinator.
            identity: ZMQ identity used for every socket created here.
        """
        # Msgpack serialization decoding. ADD requests are decoded directly
        # into EngineCoreRequest; all other request types use an untyped
        # decoder.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx,
                                    input_address,
                                    zmq.DEALER,
                                    identity=identity,
                                    bind=False))
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(ctx,
                                    coord_input_address,
                                    zmq.XSUB,
                                    identity=identity,
                                    bind=False))
                # Send subscription message to coordinator (XSUB protocol:
                # byte 0x01 followed by an empty topic subscribes to all).
                coord_socket.send(b'\x01')

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b'')
                poller.register(input_socket, zmq.POLLIN)
            if coord_socket is not None:
                poller.register(coord_socket, zmq.POLLIN)

            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(
                        copy=False)
                    request_type = EngineCoreRequestType(
                        bytes(type_frame.buffer))

                    # Deserialize the request data.
                    decoder = (add_request_decoder
                               if request_type == EngineCoreRequestType.ADD
                               else generic_decoder)
                    request = decoder.decode(data_frames)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

    def process_output_sockets(self, output_paths: list[str],
                               coord_output_path: Optional[str],
                               engine_index: int):
        """Output socket IO thread.

        Takes ``(client_index, EngineCoreOutputs)`` items from
        ``self.output_queue``, stamps each with ``engine_index``, and sends
        it on ``sockets[client_index]``. ``client_index == -1`` routes the
        message to the DP coordinator socket instead. When
        ``EngineCoreProc.ENGINE_CORE_DEAD`` is dequeued, it is forwarded to
        every client socket and the thread returns.

        Args:
            output_paths: Front-end addresses; socket ``i`` serves client
                index ``i``.
            coord_output_path: Coordinator address, or None when there is
                no coordinator.
            engine_index: Index of this engine, attached to every output.
        """
        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000))
                for output_path in output_paths
            ]
            coord_socket = stack.enter_context(
                make_zmq_socket(
                    ctx, coord_output_path, zmq.PUSH, bind=False,
                    linger=4000)) if coord_output_path is not None else None
            max_reuse_bufs = len(sockets) + 1

            while True:
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    for socket in sockets:
                        socket.send(output)
                    break
                assert not isinstance(output, bytes)
                client_index, outputs = output
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Reclaim buffers that zmq is finished with. New entries are
                # added on the left, so the oldest sends are checked first.
                while pending and pending[-1][0].done:
                    reuse_buffers.append(pending.pop()[2])

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = sockets[client_index].send_multipart(buffers,
                                                               copy=False,
                                                               track=True)
                if not tracker.done:
                    # Zero-copy send still in flight: keep the buffer (and,
                    # when extra frames reference data inside ``outputs``,
                    # the outputs object) alive until zmq is done.
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
                elif len(reuse_buffers) < max_reuse_bufs:
                    # Limit the number of buffers to reuse.
                    reuse_buffers.append(buffer)


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    # Number of busy-loop steps between two cross-rank all-reduces that
    # decide whether any DP rank still has unfinished requests (see
    # _has_global_unfinished_reqs). Every DP rank must use the same value so
    # that all ranks enter the collective on the same step.
    _FINISH_SYNC_INTERVAL = 32

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: Optional[str] = None,
    ):
        """Set up DP bookkeeping, then initialize the base EngineCoreProc.

        All arguments are forwarded to ``EngineCoreProc.__init__``, followed
        by this engine's DP rank taken from
        ``vllm_config.parallel_config.data_parallel_rank``.
        """
        self._decorate_logs()

        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.counter = 0
        # Current DP wave number (advanced in add_request,
        # _handle_client_request and run_busy_loop).
        self.current_wave = 0
        # Request counts last published by _maybe_publish_request_counts.
        self.last_counts = (0, 0)

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(vllm_config, local_client, handshake_address,
                         executor_class, log_stats, client_handshake_address,
                         dp_rank)

    def _decorate_logs(self):
        """Decorate stdout/stderr (via ``_add_prefix``) with this process's
        name and PID."""
        # Add process-specific prefix to stdout and stderr before
        # we initialize the engine.
        from multiprocessing import current_process
        process_name = current_process().name
        pid = os.getpid()
        _add_prefix(sys.stdout, process_name, pid)
        _add_prefix(sys.stderr, process_name, pid)

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Configure devices and the stateless process group for DP.

        Makes the KV-transfer engine id unique per local DP rank (when KV
        transfer is configured). Sets the platform's device control
        environment variable to this local rank's slice of ``world_size``
        devices. Then stores ``self.dp_rank`` and creates ``self.dp_group``.

        Raises:
            RuntimeError: If the device range cannot be mapped to physical
                device ids.
        """
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        if vllm_config.kv_transfer_config is not None:
            # modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug("Setting kv_transfer_config.engine_id to %s",
                         vllm_config.kv_transfer_config.engine_id)

        from vllm.platforms import current_platform
        device_control_env_var = current_platform.device_control_env_var
        world_size = vllm_config.parallel_config.world_size
        # Each local DP rank owns the contiguous device slice
        # [first_device, end_device).
        first_device = local_dp_rank * world_size
        end_device = first_device + world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            os.environ[device_control_env_var] = ",".join(
                str(current_platform.device_id_to_physical_device_id(i))
                for i in range(first_device, end_device))
        except IndexError as e:
            # RuntimeError (an Exception subclass) instead of a bare
            # Exception; existing ``except Exception`` handlers still match.
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{first_device}, "
                f"{end_device}) "
                f"base value: \"{os.getenv(device_control_env_var)}\"") from e

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        """Shut down the engine, then destroy the DP process group if any."""
        super().shutdown()
        # dp_group may not exist if initialization did not get as far as
        # _init_data_parallel.
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: EngineCoreRequest):
        """Add a request, first reconciling its wave with ours.

        Only applies when a DP coordinator is in use. A request from a newer
        wave advances ``current_wave``. A request from an older wave that
        arrives while the engines are paused makes this engine send a
        ``start_wave`` message on the coordinator path (client index -1).
        """
        if self.has_coordinator and request.current_wave != self.current_wave:
            if request.current_wave > self.current_wave:
                self.current_wave = request.current_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave)))

        super().add_request(request)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Handle START_DP_WAVE locally; defer other types to the base class.

        A START_DP_WAVE payload is ``(new_wave, exclude_eng_index)``. It is
        ignored by the engine it excludes and for waves older than ours.
        Otherwise it moves this engine to ``new_wave`` and marks the engines
        as running.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                    new_wave >= self.current_wave):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.",
                                 new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Send scheduler request counts to the coordinator path (client
        index -1) when DP load-balancing stats are enabled and the counts
        changed since the last publish."""
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(*counts)
            self.output_queue.put_nowait(
                (-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case.

        While any DP rank is running, every rank steps on every iteration,
        using a dummy batch when it has nothing to execute. This keeps the
        ranks in step with each other. Periodically an all-reduce decides
        whether the current wave is finished.
        """
        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs)

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug("Wave %d finished, pausing engine loop.",
                                 self.current_wave)
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (client_index,
                         EngineCoreOutputs(wave_complete=self.current_wave)))
                self.current_wave += 1

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank may still have unfinished requests.

        The cross-rank all-reduce only runs on every
        ``_FINISH_SYNC_INTERVAL``-th call. On the other calls this returns
        True so that the engines keep running.
        """
        # Optimization - only perform the finish-sync all-reduce
        # periodically instead of on every step.
        self.counter += 1
        if self.counter < self._FINISH_SYNC_INTERVAL:
            return True
        self.counter = 0

        return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                local_unfinished)


class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        """Record the precomputed ZMQ addresses and DP ranks, then initialize.

        ``addresses`` must be stored before ``super().__init__`` because
        ``_perform_handshakes`` (overridden below) yields it in place of a
        real handshake. For the same reason an empty handshake address is
        passed to the parent constructor.
        """
        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_rank = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = \
            local_dp_rank

        # Ray sets CUDA_VISIBLE_DEVICES to empty string,
        # we clean this up to be able to properly initialize
        # data parallel groups. pop() with a default avoids a KeyError
        # when the variable is not set at all.
        os.environ.pop('CUDA_VISIBLE_DEVICES', None)

        super().__init__(vllm_config, local_client, "", executor_class,
                         log_stats)

    def _decorate_logs(self):
        """No-op: unlike DPEngineCoreProc, stdout/stderr are left as-is."""
        pass

    @contextmanager
    def _perform_handshakes(self, handshake_address: str, identity: bytes,
                            local_client: bool, vllm_config: VllmConfig,
                            client_handshake_address: Optional[str]):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.

        Exceptions are logged and re-raised; ``shutdown()`` always runs
        afterwards.
        """
        try:
            self.run_busy_loop()
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()