# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""V1 engine core.

``EngineCore`` is the inner loop of vLLM's engine (scheduler + model
executor); ``EngineCoreProc`` wraps it with ZMQ IO so it can run in a
background process.
"""
import os
import queue
import signal
import sys
import threading
import time
from collections import deque
from collections.abc import Generator
from concurrent.futures import Future
from contextlib import ExitStack, contextmanager
from inspect import isclass, signature
from logging import DEBUG
from typing import Any, Callable, Optional, TypeVar, Union

import msgspec
import zmq

from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
from vllm.utils import (make_zmq_socket, resolve_obj_by_qualname,
                        set_process_title)
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                         unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                            EngineCoreRequestType,
                            ReconfigureDistributedRequest, ReconfigureRankType,
                            UtilityOutput, UtilityResult)
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# NOTE(review): not referenced in this part of the file; presumably a
# socket/queue polling timeout in seconds used further down — confirm.
POLLING_TIMEOUT_S = 2.5
# How long startup_handshake waits for the front-end's init message.
HANDSHAKE_TIMEOUT_MINS = 5

_R = TypeVar('_R')  # Return type for collective_rpc


class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the scheduler, the structured-output manager
    and the multi-modal input cache, and advances them one step at a time
    via ``step`` (or ``step_with_batch_queue`` when the executor reports
    more than one concurrent batch).
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 executor_class: type[Executor],
                 log_stats: bool,
                 executor_fail_callback: Optional[Callable] = None):
        """Create the executor, size the KV cache and build the scheduler.

        Args:
            vllm_config: Engine configuration. ``cache_config`` is updated in
                place with the computed GPU/CPU block counts, and
                ``scheduler_config.chunked_prefill_enabled`` is switched off
                when the model has no KV cache groups.
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Forwarded to the scheduler.
            executor_fail_callback: Optional callback registered with the
                executor through ``register_failure_callback``.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        self.vllm_config = vllm_config
        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(
                executor_fail_callback)

        # Left at -1 when the model has no KV cache (attention-free models);
        # otherwise set by _initialize_kv_caches.
        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
            self._initialize_kv_caches(vllm_config)

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache",
                            args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler. scheduler_cls is either a class or a qualified
        # name string that is resolved here.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls)

        if len(kv_cache_config.kv_cache_groups) == 0:
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            logger.info("Disabling chunked prefill for model without KVCache")
            vllm_config.scheduler_config.chunked_prefill_enabled = False

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size
            > 1,
            log_stats=self.log_stats,
        )

        # Setup MM Input Mapper.
        self.mm_input_cache_server = MirroredProcessingCache(
            vllm_config.model_config)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
                                                     SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

    def _initialize_kv_caches(
            self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
        """Determine KV cache memory, build per-worker configs, initialize.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)``.
            ``num_cpu_blocks`` is always 0 here, and the returned config is
            the first worker's (all workers are asserted to agree on the
            number of blocks).
        """
        # perf_counter is monotonic, so the reported duration is not skewed
        # by wall-clock adjustments during the (potentially long) profiling.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # Elastic EP scale-up launch: no profiling; the size comes
                # from ParallelConfig.sync_kv_cache_memory_size over the DP
                # group (presumably the value already used by existing
                # ranks — confirm) and is applied to every worker.
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = \
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                available_gpu_memory = [
                    self.available_gpu_memory_for_kv_cache
                ] * len(kv_cache_specs)
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = (
                    self.model_executor.determine_available_memory())
                self.available_gpu_memory_for_kv_cache = \
                    available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)
        # Get the kv cache tensor size
        kv_cache_configs = [
            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
                                available_gpu_memory_one_worker)
            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
            zip(kv_cache_specs, available_gpu_memory)
        ]

        # Since we use a shared centralized controller, we need the
        # `kv_cache_config` to be consistent across all workers to make sure
        # all the memory operators can be applied to all workers.
        unify_kv_cache_configs(kv_cache_configs)

        # All workers have the same kv_cache_config except layer names, so use
        # an arbitrary one to initialize the scheduler.
        assert all(cfg.num_blocks == kv_cache_configs[0].num_blocks
                   for cfg in kv_cache_configs)
        num_gpu_blocks = kv_cache_configs[0].num_blocks
        num_cpu_blocks = 0
        scheduler_kv_cache_config = kv_cache_configs[0]

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks reported by the model executor."""
        return self.model_executor.supported_tasks

    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler.

        Raises:
            TypeError: If ``request.request_id`` is not a ``str``.
            ValueError: If the request carries pooling params whose task is
                not among the executor's supported pooling tasks.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}")

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks()
                if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                                 f"Supported tasks: {supported_pooling_tasks}")

        if request.mm_hashes is not None:
            # Here, if hash exists for a multimodal input, then it will be
            # fetched from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client cache, so
            # anything that has a hash must have a HIT cache entry here
            # as well.
            assert request.mm_inputs is not None
            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
                request.mm_inputs, request.mm_hashes)

        req = Request.from_engine_core_request(request)
        if req.use_structured_output:
            # Start grammar compilation asynchronously
            self.structured_output_manager.grammar_init(req)

        if req.kv_transfer_params is not None and (
                not self.scheduler.get_kv_connector()):
            # NOTE(review): the message says transfer is disabled, but
            # req.kv_transfer_params is left untouched here; presumably the
            # scheduler ignores it when no connector exists — confirm.
            logger.warning("Got kv_transfer_params, but no KVConnector found. "
                           "Disabling KVTransfer for this request.")

        self.scheduler.add_request(req)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def execute_model_with_error_logging(
        self,
        model_fn: Callable[[SchedulerOutput], ModelRunnerOutput],
        scheduler_output: SchedulerOutput,
    ) -> ModelRunnerOutput:
        """Execute the model and log detailed info on failure.

        On an ``Exception`` from ``model_fn``, engine state is dumped via
        ``dump_engine_exception`` and the original exception is re-raised.
        """
        try:
            return model_fn(scheduler_output)
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: This method is exception-free
            dump_engine_exception(self.vllm_config, scheduler_output,
                                  self.scheduler.make_stats())
            # Bare raise preserves the original traceback without adding a
            # duplicate entry for this frame.
            raise

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        model_output = self.execute_model_with_error_logging(
            self.model_executor.execute_model,  # type: ignore
            scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output)  # type: ignore

        return (engine_core_outputs,
                scheduler_output.total_num_scheduled_tokens > 0)

    def step_with_batch_queue(
            self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, fulfilling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None
        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        if not self.batch_queue.full():
            scheduler_output = self.scheduler.schedule()
            if scheduler_output.total_num_scheduled_tokens > 0:
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        scheduled_batch = (scheduler_output is not None
                           and scheduler_output.total_num_scheduled_tokens > 0)

        # If no more requests can be scheduled and the job queue is not empty,
        # block until the first batch in the job queue is finished.
        # TODO(comaniac): Ideally we should peek the first batch in the
        # job queue to check if it's finished before scheduling a new batch,
        # but peeking the first element in a queue is not thread-safe,
        # so we need more work.
        if not scheduled_batch and not self.batch_queue.empty():
            future, scheduler_output = self.batch_queue.get_nowait()

            # Blocking until the first result is available.
            model_output = self.execute_model_with_error_logging(
                lambda _: future.result(), scheduler_output)

            self.batch_queue.task_done()
            engine_core_outputs = (self.scheduler.update_from_output(
                scheduler_output, model_output))

        return engine_core_outputs, scheduled_batch

    def shutdown(self):
        """Release the structured-output backend, executor and scheduler."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop profiling on the executor."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Clear the P1 multi-modal input cache."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
        if self.scheduler.has_unfinished_requests():
            logger.warning("Resetting the multi-modal cache when requests are "
                           "in progress may lead to desynced internal caches.")

        self.mm_input_cache_server.reset()

    def reset_prefix_cache(self):
        """Delegate to the scheduler's ``reset_prefix_cache``."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Delegate to the executor's ``sleep`` with the given level."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """Delegate to the executor's ``wake_up`` with optional tags."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` flag."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run ``execute_dummy_batch`` on all workers via collective RPC."""
        self.model_executor.collective_rpc("execute_dummy_batch")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Forward to the executor's ``save_sharded_state``."""
        self.model_executor.save_sharded_state(path=path,
                                               pattern=pattern,
                                               max_size=max_size)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Forward to the executor's ``collective_rpc``; one result each."""
        return self.model_executor.collective_rpc(method, timeout, args,
                                                  kwargs)

    def save_tensorized_model(
        self,
        tensorizer_config,
    ) -> None:
        """Forward to the executor's ``save_tensorized_model``."""
        self.model_executor.save_tensorized_model(
            tensorizer_config=tensorizer_config, )


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

421
422
    ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD'

423
424
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: Optional[str] = None,
        engine_index: int = 0,
    ):
        """Handshake with the front-end, build the engine, start IO threads.

        Args:
            vllm_config: Engine configuration.
            local_client: ``False`` marks this engine as headless in the
                startup handshake.
            handshake_address: ZMQ address of the (first) startup handshake.
            executor_class: Executor type passed on to ``EngineCore``.
            log_stats: Passed on to ``EngineCore``.
            client_handshake_address: Optional address for a second handshake
                with the colocated front-end.
            engine_index: Index of this engine; also encoded as the 2-byte
                little-endian ZMQ identity used for the handshake and inputs.
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
                                              bytes]]()

        def executor_fail_callback():
            # Enqueue an EXECUTOR_FAILED request; the busy loop reads
            # input_queue and dispatches it like any client request.
            self.input_queue.put_nowait(
                (EngineCoreRequestType.EXECUTOR_FAILED, b''))

        self.engine_index = engine_index
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        with self._perform_handshakes(handshake_address, identity,
                                      local_client, vllm_config,
                                      client_handshake_address) as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address)
            # Request-queue stats go to the coordinator only in "internal"
            # load-balancing mode.
            external_lb = vllm_config.parallel_config.data_parallel_external_lb
            self.publish_dp_lb_stats = self.has_coordinator and not external_lb

            self._init_data_parallel(vllm_config)

            # The base engine is built while the handshake is still open so
            # that the READY message can report the resulting block count.
            super().__init__(vllm_config, executor_class, log_stats,
                             executor_fail_callback)

        if self.batch_queue is not None:
            self.step_fn = self.step_with_batch_queue
        else:
            self.step_fn = self.step

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        input_thread = threading.Thread(
            target=self.process_input_sockets,
            args=(addresses.inputs, addresses.coordinator_input, identity),
            daemon=True)
        input_thread.start()

        self.output_thread = threading.Thread(
            target=self.process_output_sockets,
            args=(addresses.outputs, addresses.coordinator_output,
                  self.engine_index),
            daemon=True)
        self.output_thread.start()

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: Optional[str],
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run the startup handshake(s) and yield the engine's ZMQ addresses.

        * DP=1 or offline: one handshake with the colocated front-end.
        * DP>1 with internal load balancing: one handshake with the shared
          front-end, which may live on a different node.
        * DP>1 with external or hybrid load balancing: two handshakes. The
          first, at ``handshake_address``, is with the rank 0 front-end and
          supplies the DP Coordinator ZMQ addresses and DP process group
          address. The second, at ``client_handshake_address``, is with the
          colocated front-end and supplies the client input/output socket
          addresses. The rank 0 and colocated engines themselves skip the
          second handshake (``client_handshake_address`` is None).

        "Front-end" means either the process holding the engine core client
        (the API server process, when the API server is not scaled out) or
        the launcher process running run_multi_api_server() in serve.py.

        Once the caller's block finishes, ``vllm_config.__post_init__()`` is
        re-run because the handshake may have changed the config.
        """
        ctx = zmq.Context()
        headless = not local_client
        is_local = local_client and client_handshake_address is None
        primary = self._perform_handshake(ctx, handshake_address, identity,
                                          is_local, headless, vllm_config,
                                          vllm_config.parallel_config)

        secondary = None
        if client_handshake_address is not None:
            # A second (client) handshake requires a non-headless engine.
            assert local_client
            secondary = self._perform_handshake(ctx, client_handshake_address,
                                                identity, True, False,
                                                vllm_config)

        # Entered primary-then-secondary, exited in reverse, exactly like a
        # single multi-item `with` statement.
        with ExitStack() as stack:
            addresses = stack.enter_context(primary)
            if secondary is not None:
                # Client-facing sockets come from the colocated front-end.
                client_addresses = stack.enter_context(secondary)
                addresses.inputs = client_addresses.inputs
                addresses.outputs = client_addresses.outputs
            yield addresses

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: Optional[ParallelConfig] = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """One handshake over a DEALER socket to ``handshake_address``.

        Sends HELLO and receives the init message (see
        ``startup_handshake``), yields the addresses it contained, and, once
        the caller's block completes without error, replies with a READY
        message carrying the current GPU block count and the front-end stats
        publish address.
        """
        with make_zmq_socket(ctx,
                             handshake_address,
                             zmq.DEALER,
                             identity=identity,
                             linger=5000,
                             bind=False) as sock:
            # Register engine with front-end.
            yield self.startup_handshake(sock, local_client, headless,
                                         parallel_config_to_update)

            # Both values below are read only after the caller's block, since
            # they are filled in while the engine is being constructed.
            ready_message = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": vllm_config.cache_config.num_gpu_blocks,
                # We pass back the coordinator stats update address here for
                # the external LB case for our colocated front-end to use
                # (coordinator only runs with rank 0).
                "dp_stats_address": self.frontend_stats_publish_address,
            }
            sock.send(msgspec.msgpack.encode(ready_message))

    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: Optional[ParallelConfig] = None,
    ) -> EngineZmqAddresses:
        """Send HELLO, wait for the front-end's init message, apply it.

        Args:
            handshake_socket: Connected socket to the front-end.
            local_client: Sent as ``"local"`` in the HELLO message.
            headless: Sent as ``"headless"`` in the HELLO message.
            parallel_config: If given, every entry of the init message's
                ``parallel_config`` mapping is set on it as an attribute.

        Returns:
            The ZMQ addresses carried by the init message.

        Raises:
            RuntimeError: If no reply arrives within HANDSHAKE_TIMEOUT_MINS.
        """
        hello_message = {
            "status": "HELLO",
            "local": local_client,
            "headless": headless,
        }
        handshake_socket.send(msgspec.msgpack.encode(hello_message))

        logger.info("Waiting for init message from front-end.")
        timeout_ms = HANDSHAKE_TIMEOUT_MINS * 60_000
        if not handshake_socket.poll(timeout=timeout_ms):
            raise RuntimeError("Did not receive response from front-end "
                               f"process within {HANDSHAKE_TIMEOUT_MINS} "
                               f"minutes")

        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            handshake_socket.recv(), type=EngineHandshakeMetadata)
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for key, val in init_message.parallel_config.items():
                setattr(parallel_config, key, val)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args,
                        dp_rank: int = 0,
                        local_dp_rank: int = 0,
                        **kwargs):
        """Launch EngineCore busy loop in background process.

        Installs SIGTERM/SIGINT handlers, builds an ``EngineCoreProc`` (or a
        ``DPEngineCoreProc`` when ``data_parallel_size > 1`` or
        ``dp_rank > 0``) from ``*args``/``**kwargs`` and runs its busy loop.
        The engine is always shut down on exit if it was constructed.

        Args:
            dp_rank: Stored as ``parallel_config.data_parallel_rank`` in the
                DP case, and used in the process title.
            local_dp_rank: Stored as
                ``parallel_config.data_parallel_rank_local`` in the DP case.
            **kwargs: Must contain ``vllm_config``; forwarded together with
                ``*args`` to the engine constructor.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: Optional[EngineCoreProc] = None
        try:
            parallel_config: ParallelConfig = kwargs[
                "vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1 or dp_rank > 0:
                set_process_title("DPEngineCore", str(dp_rank))
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                set_process_title("EngineCore")
                engine_core = EngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                # Defined outside this view; presumably emits
                # ENGINE_CORE_DEAD to clients — confirm.
                engine_core._send_engine_dead()
            # Bare raise keeps the original traceback intact instead of
            # appending a second entry for this frame (as `raise e` does).
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Hook for data-parallel setup; does nothing for a single engine.

        Called from ``__init__`` while the startup handshake is open and
        before the base ``EngineCore`` is constructed. Subclasses can
        override it.
        """

    def run_busy_loop(self):
        """Core busy loop of the EngineCore.

        Alternates forever between handling client requests and stepping the
        engine. It only ends when an exception propagates out, e.g. the
        ``SystemExit`` raised by the SIGINT/SIGTERM handler.
        """
        while True:
            # Wait for work (if idle) and handle all pending client requests,
            self._process_input_queue()
            # then run one engine step and hand its outputs to the IO thread.
            self._process_engine_step()

    def _process_input_queue(self):
        """Handle client requests, returning once an engine step is due.

        While ``engines_running`` is false and the scheduler has no
        requests, blocks on ``input_queue`` one request at a time. Then it
        drains, without blocking, anything already sitting in the queue.
        """
        logged_wait = False
        while not (self.engines_running or self.scheduler.has_requests()):
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                logged_wait = True
            item = self.input_queue.get()
            self._handle_client_request(*item)

        if logged_wait:
            logger.debug("EngineCore loop active.")

        # Non-blocking drain of any requests that are still queued.
        while not self.input_queue.empty():
            item = self.input_queue.get_nowait()
            self._handle_client_request(*item)

    def _process_engine_step(self) -> bool:
695
696
697
        """Called only when there are unfinished local requests."""

        # Step the engine core.
698
        outputs, model_executed = self.step_fn()
699
        # Put EngineCoreOutputs into the output queue.
700
701
        for output in (outputs.items() if outputs else ()):
            self.output_queue.put_nowait(output)
702

703
704
        return model_executed

705
706
707
    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Dispatch request from client."""
708

709
        if request_type == EngineCoreRequestType.ADD:
710
            self.add_request(request)
711
        elif request_type == EngineCoreRequestType.ABORT:
712
            self.abort_requests(request)
713
        elif request_type == EngineCoreRequestType.UTILITY:
714
            client_idx, call_id, method_name, args = request
715
716
717
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
718
719
                result = method(*self._convert_msgspec_args(method, args))
                output.result = UtilityResult(result)
720
721
722
723
724
            except BaseException as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (f"Call to {method_name} method"
                                          f" failed: {str(e)}")
            self.output_queue.put_nowait(
725
                (client_idx, EngineCoreOutputs(utility_output=output)))
726
727
728
729
730
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error("Unrecognized input request type encountered: %s",
                         request_type)
731
732
733
734
735
736
737
738
739
740
741
742
743
744

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
         arg type, try converting to msgspec object."""
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
            msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation) else v
            for v, p in zip(args, arg_types))
745

746
747
748
749
750
751
752
753
754
755
756
757
    def _send_engine_dead(self):
        """Send EngineDead status to the EngineCoreClient."""

        # Put ENGINE_CORE_DEAD in the queue.
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Wait until msg sent by the daemon before shutdown.
        self.output_thread.join(timeout=5.0)
        if self.output_thread.is_alive():
            logger.fatal("vLLM shutdown signal from EngineCore failed "
                         "to send. Please report this issue.")

758
759
760
    def process_input_sockets(self, input_addresses: list[str],
                              coord_input_address: Optional[str],
                              identity: bytes):
761
762
763
        """Input socket IO thread."""

        # Msgpack serialization decoding.
764
765
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()
766

767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx,
                                    input_address,
                                    zmq.DEALER,
                                    identity=identity,
                                    bind=False))
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(ctx,
                                    coord_input_address,
                                    zmq.XSUB,
                                    identity=identity,
                                    bind=False))
                # Send subscription message to coordinator.
                coord_socket.send(b'\x01')

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b'')
                poller.register(input_socket, zmq.POLLIN)
            if coord_socket is not None:
                poller.register(coord_socket, zmq.POLLIN)
799

800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(
                        copy=False)
                    request_type = EngineCoreRequestType(
                        bytes(type_frame.buffer))

                    # Deserialize the request data.
                    decoder = add_request_decoder if (
                        request_type
                        == EngineCoreRequestType.ADD) else generic_decoder
                    request = decoder.decode(data_frames)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

    def process_output_sockets(self, output_paths: list[str],
                               coord_output_path: Optional[str],
                               engine_index: int):
820
821
822
        """Output socket IO thread."""

        # Msgpack serialization encoding.
823
        encoder = MsgpackEncoder()
824
825
826
827
828
829
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()
830

831
832
        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
833
834
835
836
837
838
839
840
841
842
843
844
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000))
                for output_path in output_paths
            ]
            coord_socket = stack.enter_context(
                make_zmq_socket(
                    ctx, coord_output_path, zmq.PUSH, bind=False,
                    linger=4000)) if coord_output_path is not None else None
            max_reuse_bufs = len(sockets) + 1

845
            while True:
846
847
848
849
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    for socket in sockets:
                        socket.send(output)
850
                    break
851
852
                assert not isinstance(output, bytes)
                client_index, outputs = output
853
                outputs.engine_index = engine_index
854

855
856
857
858
859
860
861
                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

862
863
864
865
866
                # Reclaim buffers that zmq is finished with.
                while pending and pending[-1][0].done:
                    reuse_buffers.append(pending.pop()[2])

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
867
                buffers = encoder.encode_into(outputs, buffer)
868
869
870
                tracker = sockets[client_index].send_multipart(buffers,
                                                               copy=False,
                                                               track=True)
871
872
873
                if not tracker.done:
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
874
875
                elif len(reuse_buffers) < max_reuse_bufs:
                    # Limit the number of buffers to reuse.
876
                    reuse_buffers.append(buffer)
877
878
879
880
881
882
883
884
885


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    def __init__(
        self,
        vllm_config: VllmConfig,
886
        local_client: bool,
887
        handshake_address: str,
888
889
        executor_class: type[Executor],
        log_stats: bool,
890
        client_handshake_address: Optional[str] = None,
891
    ):
Rui Qiao's avatar
Rui Qiao committed
892
        self._decorate_logs()
893

894
895
896
        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.counter = 0
897
        self.current_wave = 0
Rui Qiao's avatar
Rui Qiao committed
898
        self.last_counts = (0, 0)
899
900
901

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
902
903
904
        super().__init__(vllm_config, local_client, handshake_address,
                         executor_class, log_stats, client_handshake_address,
                         dp_rank)
905

Rui Qiao's avatar
Rui Qiao committed
906
907
908
909
910
911
912
913
914
    def _decorate_logs(self):
        # Add process-specific prefix to stdout and stderr before
        # we initialize the engine.
        from multiprocessing import current_process
        process_name = current_process().name
        pid = os.getpid()
        _add_prefix(sys.stdout, process_name, pid)
        _add_prefix(sys.stderr, process_name, pid)

915
916
917
    def _init_data_parallel(self, vllm_config: VllmConfig):

        # Configure GPUs and stateless process group for data parallel.
918
        dp_rank = vllm_config.parallel_config.data_parallel_rank
919
        dp_size = vllm_config.parallel_config.data_parallel_size
920
921
922
923
924
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

925
926
927
928
929
930
931
932
933
        if vllm_config.kv_transfer_config is not None:
            # modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug("Setting kv_transfer_config.engine_id to %s",
                         vllm_config.kv_transfer_config.engine_id)

934
        self.dp_rank = dp_rank
935
936
937
938
939
940
941
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        super().shutdown()
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)

942
    def add_request(self, request: EngineCoreRequest):
943
        if self.has_coordinator and request.current_wave != self.current_wave:
944
945
946
947
948
949
            if request.current_wave > self.current_wave:
                self.current_wave = request.current_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
950
                    (-1, EngineCoreOutputs(start_wave=self.current_wave)))
951
952
953
954
955
956

        super().add_request(request)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        if request_type == EngineCoreRequestType.START_DP_WAVE:
957
958
959
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                    new_wave >= self.current_wave):
960
961
962
963
964
965
966
967
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.",
                                 new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

968
    def _maybe_publish_request_counts(self):
969
        if not self.publish_dp_lb_stats:
970
971
972
973
974
975
976
977
978
979
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(*counts)
            self.output_queue.put_nowait(
                (-1, EngineCoreOutputs(scheduler_stats=stats)))

980
981
982
983
984
985
986
987
    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

988
989
            # 2) Step the engine core.
            executed = self._process_engine_step()
990
991
            self._maybe_publish_request_counts()

992
            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
993
994
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
995
996
997
                    # All engines are idle.
                    continue

998
999
                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
1000
1001
1002
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
1003
            self.engines_running = self._has_global_unfinished_reqs(
1004
1005
                local_unfinished_reqs)

1006
            if not self.engines_running:
1007
                if self.dp_rank == 0 or not self.has_coordinator:
1008
1009
1010
                    # Notify client that we are pausing the loop.
                    logger.debug("Wave %d finished, pausing engine loop.",
                                 self.current_wave)
1011
1012
1013
1014
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
1015
                    self.output_queue.put_nowait(
1016
                        (client_index,
1017
                         EngineCoreOutputs(wave_complete=self.current_wave)))
1018
                self.current_wave += 1
1019
1020
1021

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:

1022
        # Optimization - only perform finish-sync all-reduce every 32 steps.
1023
        self.counter += 1
1024
        if self.counter != 32:
1025
1026
1027
1028
1029
            return True
        self.counter = 0

        return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                local_unfinished)
Rui Qiao's avatar
Rui Qiao committed
1030

1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
    def reinitialize_distributed(
            self, reconfig_request: ReconfigureDistributedRequest) -> None:
        stateless_destroy_torch_distributed_process_group(self.dp_group)
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = \
            reconfig_request.new_data_parallel_size
        if reconfig_request.new_data_parallel_rank != -1:
            parallel_config.data_parallel_rank = \
                reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert reconfig_request.new_data_parallel_rank_local == \
            ReconfigureRankType.KEEP_CURRENT_RANK
        parallel_config.data_parallel_master_ip = \
            reconfig_request.new_data_parallel_master_ip
        parallel_config.data_parallel_master_port = \
            reconfig_request.new_data_parallel_master_port
        if reconfig_request.new_data_parallel_rank != -2:
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        reconfig_request.new_data_parallel_master_port = \
            parallel_config.data_parallel_master_port

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache)
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if reconfig_request.new_data_parallel_rank == \
        ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info("Distributed environment reinitialized for DP rank %s",
                        self.dp_rank)

Rui Qiao's avatar
Rui Qiao committed
1075
1076
1077
1078
1079
1080
1081
1082
1083

class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
1084
        local_client: bool,
Rui Qiao's avatar
Rui Qiao committed
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_rank = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = \
            local_dp_rank

1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
1106
1107
1108
1109
1110
1111
1112
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
1113
        self._set_cuda_visible_devices(vllm_config, local_dp_rank)
Rui Qiao's avatar
Rui Qiao committed
1114

1115
        super().__init__(vllm_config, local_client, "", executor_class,
Rui Qiao's avatar
Rui Qiao committed
1116
1117
                         log_stats)

1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
    def _set_cuda_visible_devices(self, vllm_config: VllmConfig,
                                  local_dp_rank: int):
        from vllm.platforms import current_platform
        device_control_env_var = current_platform.device_control_env_var
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            os.environ[device_control_env_var] = ",".join(
                str(current_platform.device_id_to_physical_device_id(i))
                for i in range(local_dp_rank *
                               world_size, (local_dp_rank + 1) * world_size))
        except IndexError as e:
            raise Exception(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f"base value: \"{os.getenv(device_control_env_var)}\"") from e

Rui Qiao's avatar
Rui Qiao committed
1136
1137
1138
1139
    def _decorate_logs(self):
        pass

    @contextmanager
1140
1141
1142
    def _perform_handshakes(self, handshake_address: str, identity: bytes,
                            local_client: bool, vllm_config: VllmConfig,
                            client_handshake_address: Optional[str]):
Rui Qiao's avatar
Rui Qiao committed
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.
        """
        try:
            self.run_busy_loop()
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()