core.py 48.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
import queue
5
import signal
6
import sys
7
8
import threading
import time
9
from collections import deque
Rui Qiao's avatar
Rui Qiao committed
10
from collections.abc import Generator
11
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
12
from contextlib import ExitStack, contextmanager
13
from inspect import isclass, signature
14
from logging import DEBUG
15
from typing import Any, Callable, Optional, TypeVar, Union
16

17
import msgspec
18
19
import zmq

20
21
22
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
23
from vllm.logger import init_logger
24
from vllm.logging_utils.dump_input import dump_engine_exception
25
from vllm.lora.request import LoRARequest
26
from vllm.tasks import POOLING_TASKS, SupportedTask
27
28
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
29
30
from vllm.utils import (bind_process_name, make_zmq_socket,
                        resolve_obj_by_qualname)
31
32
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                         unify_kv_cache_configs)
33
from vllm.v1.core.sched.interface import SchedulerInterface
34
35
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
36
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
37
38
39
                            EngineCoreRequestType,
                            ReconfigureDistributedRequest, ReconfigureRankType,
                            UtilityOutput)
40
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
41
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
42
from vllm.v1.executor.abstract import Executor
43
from vllm.v1.kv_cache_interface import KVCacheConfig
44
from vllm.v1.metrics.stats import SchedulerStats
45
from vllm.v1.outputs import ModelRunnerOutput
46
from vllm.v1.request import Request, RequestStatus
47
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
48
from vllm.v1.structured_output import StructuredOutputManager
49
50
51
52
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# Polling timeout in seconds. Not referenced in this part of the file;
# presumably consumed by the socket IO loops -- confirm.
POLLING_TIMEOUT_S = 2.5

# How long startup_handshake() waits for the front-end's init message.
HANDSHAKE_TIMEOUT_MINS = 5

# Return type for collective_rpc
_R = TypeVar('_R')

class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the scheduler, the structured-output manager and
    the mirrored multi-modal input cache, and advances them one step at a
    time via :meth:`step`, or via :meth:`step_with_batch_queue` when the
    executor reports more than one concurrent batch.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 executor_class: type[Executor],
                 log_stats: bool,
                 executor_fail_callback: Optional[Callable] = None):
        """Create the executor, size the KV cache and build the scheduler.

        Args:
            vllm_config: Engine configuration. Its
                ``cache_config.num_gpu_blocks`` / ``num_cpu_blocks`` are
                overwritten with the values determined while initializing the
                KV cache.
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Forwarded to the scheduler.
            executor_fail_callback: Optional callback registered on the
                executor through ``register_failure_callback``.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        self.vllm_config = vllm_config
        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(
                executor_fail_callback)

        # Placeholder; set by _initialize_kv_caches() when the model has a
        # KV cache.
        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
            self._initialize_kv_caches(vllm_config)

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache",
                            args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler. The class may be given as a qualified name.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls)

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size
            > 1,
            log_stats=self.log_stats,
        )

        # Setup MM Input Mapper.
        self.mm_input_cache_server = MirroredProcessingCache(
            vllm_config.model_config)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
                                                     SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

    def _initialize_kv_caches(
            self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
        """Determine KV-cache memory, build unified configs and warm up.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, kv_cache_config)``.
            ``num_cpu_blocks`` is always 0. ``kv_cache_config`` is the first
            worker's config, used to initialize the scheduler.
        """
        # Monotonic clock: only the elapsed duration is reported.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # Elastic-EP scale-up launch: take the memory size from the DP
                # group instead of profiling. Requires self.dp_group to have
                # been set beforehand (presumably by a DP subclass).
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = \
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                available_gpu_memory = [
                    self.available_gpu_memory_for_kv_cache
                ] * len(kv_cache_specs)
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = (
                    self.model_executor.determine_available_memory())
                self.available_gpu_memory_for_kv_cache = \
                    available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)
        # Get the kv cache tensor size
        kv_cache_configs = [
            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
                                available_gpu_memory_one_worker)
            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
            zip(kv_cache_specs, available_gpu_memory)
        ]

        # Since we use a shared centralized controller, we need the
        # `kv_cache_config` to be consistent across all workers to make sure
        # all the memory operators can be applied to all workers.
        unify_kv_cache_configs(kv_cache_configs)

        # All workers have the same kv_cache_config except layer names, so use
        # an arbitrary one to initialize the scheduler.
        assert all(cfg.num_blocks == kv_cache_configs[0].num_blocks
                   for cfg in kv_cache_configs)
        num_gpu_blocks = kv_cache_configs[0].num_blocks
        num_cpu_blocks = 0
        scheduler_kv_cache_config = kv_cache_configs[0]

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the model executor."""
        return self.model_executor.supported_tasks

    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler.

        Validates the pooling task (if any), resolves multi-modal inputs via
        the mirrored cache, starts grammar compilation for structured-output
        requests and hands the resulting ``Request`` to the scheduler.

        Raises:
            ValueError: if ``request.pooling_params.task`` is not one of the
                executor's supported pooling tasks.
        """
        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks()
                if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                                 f"Supported tasks: {supported_pooling_tasks}")

        if request.mm_hashes is not None:
            # Here, if hash exists for a multimodal input, then it will be
            # fetched from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client cache, so
            # anything that has a hash must have a HIT cache entry here
            # as well.
            assert request.mm_inputs is not None
            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
                request.mm_inputs, request.mm_hashes)

        req = Request.from_engine_core_request(request)
        if req.use_structured_output:
            # Start grammar compilation asynchronously
            self.structured_output_manager.grammar_init(req)

        if req.kv_transfer_params is not None and (
                not self.scheduler.get_kv_connector()):
            logger.warning("Got kv_transfer_params, but no KVConnector found. "
                           "Disabling KVTransfer for this request.")

        self.scheduler.add_request(req)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def execute_model_with_error_logging(
        self,
        model_fn: Callable[[SchedulerOutput], ModelRunnerOutput],
        scheduler_output: SchedulerOutput,
    ) -> ModelRunnerOutput:
        """Execute the model and log detailed info on failure.

        Any ``Exception`` from ``model_fn`` is dumped together with the
        scheduler output and stats, then re-raised unchanged.
        """
        try:
            return model_fn(scheduler_output)
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: This method is exception-free
            dump_engine_exception(self.vllm_config, scheduler_output,
                                  self.scheduler.make_stats())
            # Bare raise keeps the original traceback without an extra frame.
            raise

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        model_output = self.execute_model_with_error_logging(
            self.model_executor.execute_model,  # type: ignore
            scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output)  # type: ignore

        return (engine_core_outputs,
                scheduler_output.total_num_scheduled_tokens > 0)

    def step_with_batch_queue(
            self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, filling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None
        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        if not self.batch_queue.full():
            scheduler_output = self.scheduler.schedule()
            if scheduler_output.total_num_scheduled_tokens > 0:
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        scheduled_batch = (scheduler_output is not None
                           and scheduler_output.total_num_scheduled_tokens > 0)

        # If no more requests can be scheduled and the job queue is not empty,
        # block until the first batch in the job queue is finished.
        # TODO(comaniac): Ideally we should peek the first batch in the
        # job queue to check if it's finished before scheduling a new batch,
        # but peeking the first element in a queue is not thread-safe,
        # so we need more work.
        if not scheduled_batch and not self.batch_queue.empty():
            future, scheduler_output = self.batch_queue.get_nowait()

            # Blocking until the first result is available.
            model_output = self.execute_model_with_error_logging(
                lambda _: future.result(), scheduler_output)

            self.batch_queue.task_done()
            engine_core_outputs = (self.scheduler.update_from_output(
                scheduler_output, model_output))

        return engine_core_outputs, scheduled_batch

    def shutdown(self):
        """Release the structured-output backend, then shut down the
        executor and the scheduler (each only if set)."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop profiling on the executor."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Reset the mirrored multi-modal input cache on this side."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
        if self.scheduler.has_unfinished_requests():
            logger.warning("Resetting the multi-modal cache when requests are "
                           "in progress may lead to desynced internal caches.")

        self.mm_input_cache_server.reset()

    def reset_prefix_cache(self):
        """Delegate prefix-cache reset to the scheduler."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Put the executor to sleep at the given level."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """Wake the executor up, optionally only for the given tags."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` flag."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run ``execute_dummy_batch`` on all workers."""
        self.model_executor.collective_rpc("execute_dummy_batch")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate to the executor's ``add_lora``."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate to the executor's ``remove_lora``."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Delegate to the executor's ``list_loras``."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate to the executor's ``pin_lora``."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Delegate to the executor's ``save_sharded_state``."""
        self.model_executor.save_sharded_state(path=path,
                                               pattern=pattern,
                                               max_size=max_size)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Forward a collective RPC to the executor and return its results."""
        return self.model_executor.collective_rpc(method, timeout, args,
                                                  kwargs)

    def save_tensorized_model(
        self,
        tensorizer_config,
    ) -> None:
        """Delegate to the executor's ``save_tensorized_model``."""
        self.model_executor.save_tensorized_model(
            tensorizer_config=tensorizer_config, )


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

410
411
    ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD'

412
413
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: Optional[str] = None,
        engine_index: int = 0,
    ):
        """Handshake with the front-end, build the engine core and start the
        background IO threads.

        Args:
            vllm_config: Engine configuration (may be updated by handshake).
            local_client: Whether the client is local; ``False`` makes this
                engine headless.
            handshake_address: Address of the primary front-end handshake.
            executor_class: Executor type passed on to ``EngineCore``.
            log_stats: Passed on to ``EngineCore``.
            client_handshake_address: Optional address of a separate
                colocated front-end to handshake with as well.
            engine_index: Index of this engine; also its 2-byte ZMQ identity.
        """
        bind_process_name(self.__class__.__name__, f"{engine_index}")

        # Queues bridging the socket IO threads and the busy loop.
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
                                              bytes]]()

        def on_executor_failure():
            # Surface executor failures to the busy loop as a request.
            self.input_queue.put_nowait(
                (EngineCoreRequestType.EXECUTOR_FAILED, b''))

        self.engine_index = engine_index
        zmq_identity = self.engine_index.to_bytes(length=2,
                                                  byteorder="little")
        self.engines_running = False

        with self._perform_handshakes(handshake_address, zmq_identity,
                                      local_client, vllm_config,
                                      client_handshake_address) as addresses:
            self.client_count = len(addresses.outputs)

            # Data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address)
            # Request queue stats go to the coordinator only in "internal"
            # LB mode.
            self.publish_dp_lb_stats = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb)

            self._init_data_parallel(vllm_config)

            super().__init__(vllm_config, executor_class, log_stats,
                             on_executor_failure)

        if self.batch_queue is None:
            self.step_fn = self.step
        else:
            self.step_fn = self.step_with_batch_queue

        # Background IO threads let ZMQ socket IO (which releases the GIL)
        # and some (de)serialization overlap with the model forward pass.
        # The threads move data between sockets and queues; the busy loop
        # only touches the queues.
        input_thread = threading.Thread(
            target=self.process_input_sockets,
            args=(addresses.inputs, addresses.coordinator_input,
                  zmq_identity),
            daemon=True)
        input_thread.start()

        self.output_thread = threading.Thread(
            target=self.process_output_sockets,
            args=(addresses.outputs, addresses.coordinator_output,
                  self.engine_index),
            daemon=True)
        self.output_thread.start()

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: Optional[str],
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Run the startup handshake(s) and yield the resulting ZMQ addresses.

        With DP=1 or in offline mode, the peer is the colocated front-end
        process.

        With DP>1 and internal load balancing, the peer is the shared
        front-end process, which may live on another node.

        With DP>1 and external or hybrid load balancing there are two
        handshakes:
            - one with the rank 0 front-end process, which supplies the
              DP Coordinator ZMQ addresses and the DP process group address;
            - one with the colocated front-end process, which supplies the
              client input/output socket addresses.
        The rank 0 and colocated engines themselves skip the second one.

        "Front-end" process here means either the process holding the engine
        core client (the API server, when it is not scaled out) or the
        launcher process running run_multi_api_server() in serve.py.

        After the ``with`` body finishes without raising,
        ``vllm_config.__post_init__()`` is re-run because the handshake may
        have changed the config.
        """
        zmq_ctx = zmq.Context()
        headless = not local_client
        # The primary handshake is "local" only when no separate colocated
        # front-end handshake is required.
        primary_is_local = local_client and client_handshake_address is None
        primary = self._perform_handshake(zmq_ctx, handshake_address,
                                          identity, primary_is_local,
                                          headless, vllm_config,
                                          vllm_config.parallel_config)

        if client_handshake_address is not None:
            assert local_client
            colocated = self._perform_handshake(zmq_ctx,
                                                client_handshake_address,
                                                identity, True, False,
                                                vllm_config)
            with primary as addresses, colocated as client_addresses:
                # Client IO sockets come from the colocated front-end.
                addresses.inputs = client_addresses.inputs
                addresses.outputs = client_addresses.outputs
                yield addresses
        else:
            with primary as addresses:
                yield addresses

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: Optional[ParallelConfig] = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run one HELLO/READY exchange with a front-end over a DEALER socket.

        Registers via ``startup_handshake`` and yields the addresses it
        returns. If the ``with`` body completes without raising, a READY
        message is then sent carrying ``cache_config.num_gpu_blocks`` and
        ``self.frontend_stats_publish_address`` before the socket is closed.
        """
        with make_zmq_socket(ctx,
                             handshake_address,
                             zmq.DEALER,
                             identity=identity,
                             linger=5000,
                             bind=False) as sock:
            # Register engine with front-end and hand its addresses out.
            yield self.startup_handshake(sock, local_client, headless,
                                         parallel_config_to_update)

            # The coordinator stats update address is passed back so that,
            # in the external LB case, our colocated front-end can use it
            # (the coordinator only runs with rank 0).
            ready_message = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": vllm_config.cache_config.num_gpu_blocks,
                "dp_stats_address": self.frontend_stats_publish_address,
            }
            sock.send(msgspec.msgpack.encode(ready_message))

    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: Optional[ParallelConfig] = None,
    ) -> EngineZmqAddresses:
        """Send HELLO to the front-end and wait for its init message.

        Args:
            handshake_socket: Connected socket to the front-end.
            local_client: Sent as the ``local`` field of HELLO.
            headless: Sent as the ``headless`` field of HELLO.
            parallel_config: If given, every key/value in the init message's
                ``parallel_config`` mapping is set on it as an attribute.

        Returns:
            The ``addresses`` field of the decoded init message.

        Raises:
            RuntimeError: if nothing arrives within HANDSHAKE_TIMEOUT_MINS.
        """
        hello = msgspec.msgpack.encode({
            "status": "HELLO",
            "local": local_client,
            "headless": headless,
        })
        handshake_socket.send(hello)

        logger.info("Waiting for init message from front-end.")
        timeout_ms = HANDSHAKE_TIMEOUT_MINS * 60_000
        if not handshake_socket.poll(timeout=timeout_ms):
            raise RuntimeError("Did not receive response from front-end "
                               f"process within {HANDSHAKE_TIMEOUT_MINS} "
                               f"minutes")

        init_message = msgspec.msgpack.decode(handshake_socket.recv(),
                                              type=EngineHandshakeMetadata)
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            overrides = init_message.parallel_config
            for attr_name, attr_value in overrides.items():
                setattr(parallel_config, attr_name, attr_value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args,
                        dp_rank: int = 0,
                        local_dp_rank: int = 0,
                        **kwargs):
        """Launch EngineCore busy loop in background process.

        Installs SIGTERM/SIGINT handlers, then builds a ``DPEngineCoreProc``
        (when ``data_parallel_size > 1`` or ``dp_rank > 0``) or a plain
        ``EngineCoreProc`` and runs its busy loop. The engine is always shut
        down on exit; on an unexpected error it first reports itself dead.

        Args:
            *args, **kwargs: Forwarded to the engine constructor. ``kwargs``
                must contain ``vllm_config``.
            dp_rank: Stored as ``parallel_config.data_parallel_rank``.
            local_dp_rank: Stored as
                ``parallel_config.data_parallel_rank_local``.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: Optional[EngineCoreProc] = None
        try:
            parallel_config: ParallelConfig = kwargs[
                "vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1 or dp_rank > 0:
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                engine_core = EngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            # Bare raise keeps the original traceback without an extra frame.
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        pass

653
654
655
    def run_busy_loop(self):
        """Core busy loop of the EngineCore.

        Alternates forever between handling client requests and stepping the
        engine. It only ends by an exception, e.g. the SystemExit raised by
        the signal handler installed in run_engine_core().
        """
        while True:
            # Block on client input until there is something to step on.
            self._process_input_queue()
            # Run one engine step and publish its outputs.
            self._process_engine_step()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed.

        Blocks on the input queue while ``engines_running`` is false and the
        scheduler has no requests, then drains any remaining queued requests
        without blocking.
        """
        logged_idle = False

        while not (self.engines_running or self.scheduler.has_requests()):
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                logged_idle = True
            queued_item = self.input_queue.get()
            self._handle_client_request(*queued_item)

        if logged_idle:
            logger.debug("EngineCore loop active.")

        # Drain whatever else has already arrived, without blocking.
        while not self.input_queue.empty():
            self._handle_client_request(*self.input_queue.get_nowait())

    def _process_engine_step(self) -> bool:
        """Called only when there are unfinished local requests.

        Runs ``step_fn()`` once, puts each (client index, EngineCoreOutputs)
        item of its result onto the output queue, and returns whether the
        model was executed.
        """
        outputs, model_executed = self.step_fn()
        if outputs:
            for client_output in outputs.items():
                self.output_queue.put_nowait(client_output)
        return model_executed

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Dispatch request from client.

        Args:
            request_type: Kind of request received from the client.
            request: Decoded payload. For ``UTILITY`` requests this is a
                ``(client_idx, call_id, method_name, args)`` tuple. The named
                method is invoked on this engine, and its result (or a
                failure message) is queued back for client ``client_idx``.

        Raises:
            RuntimeError: On an ``EXECUTOR_FAILED`` request.
            BaseException: A non-``Exception`` error (e.g. ``SystemExit`` or
                ``KeyboardInterrupt``) raised inside a utility method is
                re-raised after the failure has been reported to the caller.
        """

        if request_type == EngineCoreRequestType.ADD:
            self.add_request(request)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            # Set when the utility call is interrupted by a BaseException
            # that is not an Exception; it must keep propagating after we
            # have replied to the caller.
            interrupt: Optional[BaseException] = None
            try:
                method = getattr(self, method_name)
                output.result = method(
                    *self._convert_msgspec_args(method, args))
            except BaseException as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (f"Call to {method_name} method"
                                          f" failed: {str(e)}")
                if not isinstance(e, Exception):
                    interrupt = e
            # Always reply, so the caller is not left waiting for a response.
            self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=output)))
            if interrupt is not None:
                # Swallowing SystemExit/KeyboardInterrupt here would turn an
                # exit request that arrives mid-call into an ordinary failed
                # utility call and keep the busy loop running.
                raise interrupt
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error("Unrecognized input request type encountered: %s",
                         request_type)
719
720
721
722
723
724
725
726
727
728
729
730
731
732

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
         arg type, try converting to msgspec object."""
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
            msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation) else v
            for v, p in zip(args, arg_types))
733

734
735
736
737
738
739
740
741
742
743
744
745
    def _send_engine_dead(self):
        """Notify the EngineCoreClient that this engine core is dead.

        Queues the ENGINE_CORE_DEAD sentinel, then waits up to 5 seconds for
        the output thread to finish. A fatal message is logged if it is still
        alive after that.
        """
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # NOTE: output_thread is presumably the thread running
        # process_output_sockets, which breaks out of its loop after sending
        # the sentinel; joining it therefore waits for the send.
        self.output_thread.join(timeout=5.0)
        if not self.output_thread.is_alive():
            return
        logger.fatal("vLLM shutdown signal from EngineCore failed "
                     "to send. Please report this issue.")

746
747
748
    def process_input_sockets(self, input_addresses: list[str],
                              coord_input_address: Optional[str],
                              identity: bytes):
        """Input socket IO thread.

        Connects a DEALER socket to every front-end input address and, when
        ``coord_input_address`` is given, an XSUB socket to the DP
        coordinator. It then loops forever:

        1. read ``(request_type, payload...)`` multipart messages;
        2. decode the payload;
        3. push ``(request_type, request)`` onto ``self.input_queue`` for the
           busy loop.
        """

        # ADD requests decode straight into EngineCoreRequest; every other
        # request type goes through the untyped decoder.
        add_decoder = MsgpackDecoder(EngineCoreRequest)
        untyped_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:

            def _connect(address, socket_type):
                # Sockets are closed by the ExitStack when the thread exits.
                return stack.enter_context(
                    make_zmq_socket(ctx,
                                    address,
                                    socket_type,
                                    identity=identity,
                                    bind=False))

            input_sockets = [
                _connect(address, zmq.DEALER) for address in input_addresses
            ]
            coord_socket = None
            if coord_input_address is not None:
                coord_socket = _connect(coord_input_address, zmq.XSUB)
                # Subscribe (0x01, empty topic) so the coordinator's
                # publications reach us.
                coord_socket.send(b'\x01')

            poller = zmq.Poller()
            for sock in input_sockets:
                # The front-end ROUTER socket cannot send input messages back
                # to us until it has received an initial message from us.
                sock.send(b'')
                poller.register(sock, zmq.POLLIN)
            if coord_socket is not None:
                poller.register(coord_socket, zmq.POLLIN)

            while True:
                for ready_socket, _ in poller.poll():
                    type_frame, *payload_frames = ready_socket.recv_multipart(
                        copy=False)
                    req_type = EngineCoreRequestType(bytes(type_frame.buffer))

                    if req_type == EngineCoreRequestType.ADD:
                        decoder = add_decoder
                    else:
                        decoder = untyped_decoder
                    decoded = decoder.decode(payload_frames)

                    self.input_queue.put_nowait((req_type, decoded))

    def process_output_sockets(self, output_paths: list[str],
                               coord_output_path: Optional[str],
                               engine_index: int):
        """Output socket IO thread.

        Takes items from ``self.output_queue`` and routes them as follows:

        - ``(client_index, EngineCoreOutputs)`` pairs are stamped with
          ``engine_index`` and pushed to that client's socket.
        - Pairs with ``client_index == -1`` go to the DP coordinator socket.
        - The ENGINE_CORE_DEAD sentinel is sent to every client socket, after
          which the thread exits.
        """

        encoder = MsgpackEncoder()

        # Spare serialization buffers that can be reused.
        free_buffers: list[bytearray] = []
        # In-flight zero-copy sends, newest at the left. Each entry keeps the
        # zmq tracker, the outputs object (its tensors/np arrays may back
        # extra frames that were sent without copying) and the buffer alive
        # until zmq is done with them.
        in_flight = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # linger=4000 (ms) lets queued messages, in particular
        # ENGINE_CORE_DEAD, still be delivered while the sockets are closing.
        with ExitStack() as stack, zmq.Context() as ctx:
            client_sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, path, zmq.PUSH, linger=4000))
                for path in output_paths
            ]
            coord_socket = None
            if coord_output_path is not None:
                coord_socket = stack.enter_context(
                    make_zmq_socket(ctx,
                                    coord_output_path,
                                    zmq.PUSH,
                                    bind=False,
                                    linger=4000))
            buffer_cap = len(client_sockets) + 1

            while True:
                item = self.output_queue.get()
                if item == EngineCoreProc.ENGINE_CORE_DEAD:
                    for sock in client_sockets:
                        sock.send(item)
                    break

                assert not isinstance(item, bytes)
                client_index, outputs = item
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Coordinator messages are very small, so they skip the
                    # buffer-reuse machinery.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Recycle buffers of the oldest sends zmq has finished with.
                while in_flight and in_flight[-1][0].done:
                    free_buffers.append(in_flight.pop()[2])

                buffer = free_buffers.pop() if free_buffers else bytearray()
                frames = encoder.encode_into(outputs, buffer)
                tracker = client_sockets[client_index].send_multipart(
                    frames, copy=False, track=True)

                if not tracker.done:
                    # outputs only needs pinning when extra frames were
                    # produced beyond the main buffer.
                    keep_alive = outputs if len(frames) > 1 else None
                    in_flight.appendleft((tracker, keep_alive, buffer))
                elif len(free_buffers) < buffer_cap:
                    # Bound how many idle buffers are kept around.
                    free_buffers.append(buffer)
865
866
867
868
869
870
871
872
873


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: Optional[str] = None,
    ):
        # Decorate stdout/stderr before the engine is initialized so that all
        # of its output carries the process prefix.
        self._decorate_logs()

        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.counter = 0
        self.current_wave = 0
        # Request counts last published by _maybe_publish_request_counts();
        # used to skip publishing when nothing changed.
        self.last_counts = (0, 0)

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(vllm_config, local_client, handshake_address,
                         executor_class, log_stats, client_handshake_address,
                         dp_rank)

    def _decorate_logs(self):
        """Prefix stdout and stderr with this process's name and pid."""
        # Add process-specific prefix to stdout and stderr before
        # we initialize the engine.
        from multiprocessing import current_process
        process_name = current_process().name
        pid = os.getpid()
        _add_prefix(sys.stdout, process_name, pid)
        _add_prefix(sys.stderr, process_name, pid)

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Record this DP rank and create the stateless DP process group.

        When a KV-transfer config is present, its ``engine_id`` is also
        suffixed with ``_dp<local_dp_rank>`` so it is unique per DP rank.
        """

        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        if vllm_config.kv_transfer_config is not None:
            # modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug("Setting kv_transfer_config.engine_id to %s",
                         vllm_config.kv_transfer_config.engine_id)

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        """Shut down the engine core, then destroy the DP group if set."""
        super().shutdown()
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: EngineCoreRequest):
        """Add a request after reconciling this engine's DP wave with it.

        With a coordinator, there are two cases when the request's wave
        differs from ours:

        - A newer wave is adopted.
        - A request for an older wave that arrives while we are idle makes
          us ask the front-end (via client index -1) to start the next wave.
        """
        if self.has_coordinator and request.current_wave != self.current_wave:
            if request.current_wave > self.current_wave:
                self.current_wave = request.current_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave)))

        super().add_request(request)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Handle START_DP_WAVE here; defer every other type to the base.

        A START_DP_WAVE payload is ``(new_wave, exclude_eng_index)``. It is
        ignored by the excluded engine and when ``new_wave`` is older than
        the current wave. Otherwise the wave is adopted and, if idle, the
        engine starts running.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                    new_wave >= self.current_wave):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.",
                                 new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Send scheduler request counts to the coordinator if they changed.

        Does nothing unless ``publish_dp_lb_stats`` is set.
        """
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(*counts)
            self.output_queue.put_nowait(
                (-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs)

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug("Wave %d finished, pausing engine loop.",
                                 self.current_wave)
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (client_index,
                         EngineCoreOutputs(wave_complete=self.current_wave)))
                self.current_wave += 1

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank still has unfinished requests.

        The cross-rank check only runs on every 32nd call. On the other
        calls, True is returned without communicating.
        """
        # Optimization - only perform finish-sync all-reduce every 32 steps.
        self.counter += 1
        if self.counter != 32:
            return True
        self.counter = 0

        return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                local_unfinished)

    def reinitialize_distributed(
            self, reconfig_request: ReconfigureDistributedRequest) -> None:
        """Rebuild this rank's data-parallel setup for a new DP size.

        Steps:

        1. Tear down the current DP group and shut down the engine core.
        2. Apply the new DP size, master IP and port (and the new rank,
           unless it is ``KEEP_CURRENT_RANK``).
        3. Re-join the resized DP group, unless this rank is being removed
           (``SHUTDOWN_CURRENT_RANK``).
        4. Forward the request to the model executor.
        5. When the DP size grows, share this rank's KV-cache memory budget
           with the new ranks and warm up the model.
        6. A removed rank shuts down at the end.
        """
        # NOTE(review): shutdown() also destroys self.dp_group, so the group
        # destroyed explicitly here is destroyed a second time by it (and a
        # third time by the final shutdown() of a removed rank). Confirm that
        # stateless_destroy_torch_distributed_process_group tolerates that.
        stateless_destroy_torch_distributed_process_group(self.dp_group)
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = \
            reconfig_request.new_data_parallel_size
        if (reconfig_request.new_data_parallel_rank
                != ReconfigureRankType.KEEP_CURRENT_RANK):
            parallel_config.data_parallel_rank = \
                reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert reconfig_request.new_data_parallel_rank_local == \
            ReconfigureRankType.KEEP_CURRENT_RANK
        parallel_config.data_parallel_master_ip = \
            reconfig_request.new_data_parallel_master_ip
        parallel_config.data_parallel_master_port = \
            reconfig_request.new_data_parallel_master_port
        if (reconfig_request.new_data_parallel_rank
                != ReconfigureRankType.SHUTDOWN_CURRENT_RANK):
            # This rank stays in the resized DP group: re-join it.
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        # Copy the master port now held by the parallel config back into the
        # request before it is forwarded to the executor.
        # NOTE(review): presumably stateless_init_dp_group() may change it.
        reconfig_request.new_data_parallel_master_port = \
            parallel_config.data_parallel_master_port

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache)
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if (reconfig_request.new_data_parallel_rank
                == ReconfigureRankType.SHUTDOWN_CURRENT_RANK):
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info("Distributed environment reinitialized for DP rank %s",
                        self.dp_rank)

Rui Qiao's avatar
Rui Qiao committed
1063
1064
1065
1066
1067
1068
1069
1070
1071

class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Stored before super().__init__() because _perform_handshakes()
        # yields these addresses (presumably it runs during base init).
        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_rank = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = \
            local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_cuda_visible_devices(vllm_config, local_dp_rank)

        super().__init__(vllm_config, local_client, "", executor_class,
                         log_stats)

    def _set_cuda_visible_devices(self, vllm_config: VllmConfig,
                                  local_dp_rank: int):
        """Point the platform's device-control env var at this rank's devices.

        Local DP rank ``r`` with parallel world size ``w`` is given device ids
        ``[r * w, (r + 1) * w)``. Each id is translated through
        ``current_platform.device_id_to_physical_device_id`` and the results
        are joined with commas.

        Raises:
            RuntimeError: If the translation raises ``IndexError``, e.g. when
                the range exceeds the devices listed in the current value of
                the env var.
        """
        from vllm.platforms import current_platform
        device_control_env_var = current_platform.device_control_env_var
        world_size = vllm_config.parallel_config.world_size
        start = local_dp_rank * world_size
        end = (local_dp_rank + 1) * world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            os.environ[device_control_env_var] = ",".join(
                str(current_platform.device_id_to_physical_device_id(i))
                for i in range(start, end))
        except IndexError as e:
            # RuntimeError instead of a bare Exception: more specific, and
            # still caught by existing `except Exception` handlers.
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{start}, {end}) "
                f"base value: \"{os.getenv(device_control_env_var)}\"") from e

    def _decorate_logs(self):
        """No-op: the stdout/stderr prefixing of the base class is skipped."""
        pass

    @contextmanager
    def _perform_handshakes(self, handshake_address: str, identity: bytes,
                            local_client: bool, vllm_config: VllmConfig,
                            client_handshake_address: Optional[str]):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.

        SystemExit is logged at debug level and other exceptions are logged
        as fatal. Both are re-raised, and the engine core is always shut
        down on the way out.
        """
        try:
            self.run_busy_loop()
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()