# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import queue
import signal
import threading
import time
from collections import deque
from collections.abc import Generator
from concurrent.futures import Future
from contextlib import ExitStack, contextmanager
from inspect import isclass, signature
from logging import DEBUG
from typing import Any, Callable, Optional, TypeVar, Union

import msgspec
import zmq

from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
from vllm.utils import (decorate_logs, make_zmq_socket,
                        resolve_obj_by_qualname, set_process_title)
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                         unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                            EngineCoreRequestType,
                            ReconfigureDistributedRequest, ReconfigureRankType,
                            UtilityOutput, UtilityResult)
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# NOTE(review): not referenced in the visible part of this module; the `_S`
# suffix suggests seconds — confirm at the use site.
POLLING_TIMEOUT_S = 2.5
# How long EngineCoreProc.startup_handshake() waits for the front-end's init
# message before raising, in minutes (converted to ms for zmq poll()).
HANDSHAKE_TIMEOUT_MINS = 5

_R = TypeVar('_R')  # Return type for collective_rpc


class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the scheduler, the structured-output manager
    and the multi-modal input cache mirror, and advances them one step at a
    time via :meth:`step`, or via :meth:`step_with_batch_queue` when the
    batch queue exists (it is created only when the executor reports more
    than one concurrent batch).
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 executor_class: type[Executor],
                 log_stats: bool,
                 executor_fail_callback: Optional[Callable] = None):
        """Build the executor, size the KV cache and create the scheduler.

        Args:
            vllm_config: Engine configuration. Its ``cache_config`` is
                updated in place with the profiled block counts, and
                ``scheduler_config.chunked_prefill_enabled`` may be turned
                off for models without KV cache.
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Forwarded to the scheduler.
            executor_fail_callback: If given, registered on the executor via
                ``register_failure_callback``.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        self.vllm_config = vllm_config
        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(
                executor_fail_callback)

        # Sentinel; _initialize_kv_caches overwrites it when the model has a
        # KV cache.
        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
            self._initialize_kv_caches(vllm_config)

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache",
                            args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler. scheduler_cls may be a class or a qualified name.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls)

        if len(kv_cache_config.kv_cache_groups) == 0:
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            logger.info("Disabling chunked prefill for model without KVCache")
            vllm_config.scheduler_config.chunked_prefill_enabled = False

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=(
                vllm_config.parallel_config.data_parallel_size > 1),
            log_stats=self.log_stats,
        )

        # Setup MM Input Mapper.
        self.mm_input_cache_server = MirroredProcessingCache(
            vllm_config.model_config)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
                                                     SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

    def _initialize_kv_caches(
            self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
        """Size the KV cache, initialize it on the workers and warm up.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)``.
            ``num_cpu_blocks`` is always 0 here, and the returned config is
            the first per-worker config (all are unified to the same block
            count).
        """
        # Monotonic clock: we only need an elapsed interval.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        if any(kv_cache_specs):
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # NOTE(review): assumes a data-parallel subclass has set
                # `self.dp_group` before this runs — confirm.
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = \
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                available_gpu_memory = [
                    self.available_gpu_memory_for_kv_cache
                ] * len(kv_cache_specs)
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = (
                    self.model_executor.determine_available_memory())
                self.available_gpu_memory_for_kv_cache = \
                    available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)
        # Get the kv cache tensor size
        kv_cache_configs = [
            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
                                available_gpu_memory_one_worker)
            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
            zip(kv_cache_specs, available_gpu_memory)
        ]

        # Since we use a shared centralized controller, we need the
        # `kv_cache_config` to be consistent across all workers to make sure
        # all the memory operators can be applied to all workers.
        unify_kv_cache_configs(kv_cache_configs)

        # All workers have the same kv_cache_config except layer names, so use
        # an arbitrary one to initialize the scheduler.
        assert all(cfg.num_blocks == kv_cache_configs[0].num_blocks
                   for cfg in kv_cache_configs)
        num_gpu_blocks = kv_cache_configs[0].num_blocks
        num_cpu_blocks = 0
        scheduler_kv_cache_config = kv_cache_configs[0]

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the model executor reports as supported."""
        return self.model_executor.supported_tasks

    def add_request(self, request: Request, request_wave: int = 0):
        """Validate a request and add it to the scheduler.

        Args:
            request: The request to schedule.
            request_wave: Which wave of requests this is expected to belong
                to in the DP case. Not used by this base implementation.

        Raises:
            TypeError: If ``request.request_id`` is not a ``str``.
            ValueError: If the request has pooling params whose task is not
                among the supported pooling tasks.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}")

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks()
                if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                                 f"Supported tasks: {supported_pooling_tasks}")

        if request.kv_transfer_params is not None and (
                not self.scheduler.get_kv_connector()):
            logger.warning("Got kv_transfer_params, but no KVConnector found. "
                           "Disabling KVTransfer for this request.")

        self.scheduler.add_request(request)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def execute_model_with_error_logging(
        self,
        model_fn: Callable[[SchedulerOutput], ModelRunnerOutput],
        scheduler_output: SchedulerOutput,
    ) -> ModelRunnerOutput:
        """Run ``model_fn(scheduler_output)``, dumping engine state on error.

        Any ``Exception`` is re-raised unchanged after the dump.
        """
        try:
            return model_fn(scheduler_output)
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: This method is exception-free
            dump_engine_exception(self.vllm_config, scheduler_output,
                                  self.scheduler.make_stats())
            # Bare raise re-raises the original with its traceback intact.
            raise

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        model_output = self.execute_model_with_error_logging(
            self.model_executor.execute_model,  # type: ignore
            scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output)  # type: ignore

        return (engine_core_outputs,
                scheduler_output.total_num_scheduled_tokens > 0)

    def step_with_batch_queue(
            self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, fulfilling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.

        Returns:
            ``(outputs or None, whether a new batch was scheduled)``.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None
        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        if not self.batch_queue.full():
            scheduler_output = self.scheduler.schedule()
            if scheduler_output.total_num_scheduled_tokens > 0:
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        scheduled_batch = (scheduler_output is not None
                           and scheduler_output.total_num_scheduled_tokens > 0)

        # If no more requests can be scheduled and the job queue is not empty,
        # block until the first batch in the job queue is finished.
        # TODO(comaniac): Ideally we should peek the first batch in the
        # job queue to check if it's finished before scheduling a new batch,
        # but peeking the first element in a queue is not thread-safe,
        # so we need more work.
        if not scheduled_batch and not self.batch_queue.empty():
            future, scheduler_output = self.batch_queue.get_nowait()

            # Blocking until the first result is available.
            model_output = self.execute_model_with_error_logging(
                lambda _: future.result(), scheduler_output)

            self.batch_queue.task_done()
            engine_core_outputs = (self.scheduler.update_from_output(
                scheduler_output, model_output))

        return engine_core_outputs, scheduled_batch

    def shutdown(self):
        """Release the structured-output backend, executor and scheduler.

        The executor and scheduler are shut down even if clearing the
        structured-output backend raises, so workers are not leaked on that
        path; the original error still propagates.
        """
        try:
            self.structured_output_manager.clear_backend()
        finally:
            if self.model_executor:
                self.model_executor.shutdown()
            if self.scheduler:
                self.scheduler.shutdown()

    def profile(self, is_start: bool = True):
        """Forward a profiler start/stop request to the model executor."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Reset the mirrored multi-modal input cache.

        Logs a warning if the scheduler still has unfinished requests.
        """
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
        if self.scheduler.has_unfinished_requests():
            logger.warning("Resetting the multi-modal cache when requests are "
                           "in progress may lead to desynced internal caches.")

        self.mm_input_cache_server.reset()

    def reset_prefix_cache(self):
        """Delegate to ``scheduler.reset_prefix_cache()``."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Delegate to ``model_executor.sleep(level)``."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """Delegate to ``model_executor.wake_up(tags)``."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` attribute."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Invoke ``execute_dummy_batch`` through the executor's RPC."""
        self.model_executor.collective_rpc("execute_dummy_batch")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate to ``model_executor.add_lora``."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate to ``model_executor.remove_lora``."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Delegate to ``model_executor.list_loras``."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate to ``model_executor.pin_lora``."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Delegate to ``model_executor.save_sharded_state``."""
        self.model_executor.save_sharded_state(path=path,
                                               pattern=pattern,
                                               max_size=max_size)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Delegate to ``model_executor.collective_rpc``.

        Returns the executor's list of results (presumably one entry per
        worker — confirm against the executor implementation).
        """
        return self.model_executor.collective_rpc(method, timeout, args,
                                                  kwargs)

    def save_tensorized_model(
        self,
        tensorizer_config,
    ) -> None:
        """Delegate to ``model_executor.save_tensorized_model``."""
        self.model_executor.save_tensorized_model(
            tensorizer_config=tensorizer_config, )

    def preprocess_add_request(
            self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Preprocess the request.

        This function could be directly used in input processing thread to allow
        request initialization running in parallel with Model forward

        Returns:
            ``(Request, request.current_wave)``.
        """
        if request.mm_hashes is not None:
            assert request.mm_inputs is not None
            # Note on thread safety: no race condition.
            # `mm_input_cache_server` is reset at the end of LLMEngine init,
            # and will only be accessed in the input processing thread
            # afterwards.
            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
                request.mm_inputs, request.mm_hashes)

        req = Request.from_engine_core_request(request)
        if req.use_structured_output:
            # Note on thread safety: no race condition.
            # `grammar_init` is only invoked in input processing thread. For
            # `structured_output_manager`, each request is independent and
            # grammar compilation is async. Scheduler always checks grammar
            # compilation status before scheduling request.
            self.structured_output_manager.grammar_init(req)
        return req, request.current_wave


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

433
434
    ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD'

435
436
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: Optional[str] = None,
        engine_index: int = 0,
    ):
        """Handshake with the front-end, build the core, start I/O threads.

        Args:
            vllm_config: Engine configuration.
            local_client: Whether the client runs locally.
            handshake_address: ZMQ address of the (primary) front-end
                handshake socket.
            executor_class: Executor type for the base ``EngineCore``.
            log_stats: Forwarded to the base ``EngineCore``.
            client_handshake_address: If set, a second handshake is performed
                with the colocated front-end at this address.
            engine_index: Index of this engine; also encoded as its ZMQ
                identity.
        """
        # Queues bridging the socket I/O threads and the busy loop.
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
                                              bytes]]()

        def _on_executor_failure():
            # Enqueue an EXECUTOR_FAILED request; _process_input_queue picks
            # it up and hands it to _handle_client_request.
            self.input_queue.put_nowait(
                (EngineCoreRequestType.EXECUTOR_FAILED, b''))

        self.engine_index = engine_index
        # 2-byte little-endian engine index, used as the ZMQ identity of the
        # handshake socket(s) and passed to the input-socket thread.
        identity = engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        handshakes = self._perform_handshakes(handshake_address, identity,
                                              local_client, vllm_config,
                                              client_handshake_address)
        with handshakes as addresses:
            self.client_count = len(addresses.outputs)

            # Data-parallel state derived from the handshake.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address)
            # Request-queue stats go to the coordinator only in "internal"
            # LB mode.
            self.publish_dp_lb_stats = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb)

            self._init_data_parallel(vllm_config)

            # The base init runs inside the handshake context so that the
            # READY message sent on exit reports the num_gpu_blocks it sets.
            super().__init__(vllm_config, executor_class, log_stats,
                             _on_executor_failure)

        if self.batch_queue is None:
            self.step_fn = self.step
        else:
            self.step_fn = self.step_with_batch_queue

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        input_thread = threading.Thread(
            target=self.process_input_sockets,
            args=(addresses.inputs, addresses.coordinator_input, identity),
            daemon=True)
        input_thread.start()

        self.output_thread = threading.Thread(
            target=self.process_output_sockets,
            args=(addresses.outputs, addresses.coordinator_output,
                  self.engine_index),
            daemon=True)
        self.output_thread.start()

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: Optional[str],
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Run the startup handshake(s) and yield the resulting ZMQ addresses.

        For DP=1 or offline mode there is a single handshake with the
        colocated front-end process.

        For DP>1 with internal load balancing the handshake is with the
        shared front-end process, which may live on a different node.

        For DP>1 with external or hybrid load balancing two handshakes are
        performed:

            - With the rank 0 front-end process, which supplies the DP
              Coordinator ZMQ addresses and the DP process group address.
            - With the colocated front-end process, which supplies the
              client input/output socket addresses.

        The rank 0 and colocated engines themselves skip the second one.

        A "front-end" process is either the process containing the engine
        core client (the API server process when the API server is not
        scaled out), or the launcher process running run_multi_api_server()
        in serve.py.

        After the caller's block completes, ``vllm_config.__post_init__()``
        is re-run because the handshake may have changed the config.
        """
        ctx = zmq.Context()
        needs_client_handshake = client_handshake_address is not None
        primary = self._perform_handshake(
            ctx,
            handshake_address,
            identity,
            local_client and not needs_client_handshake,  # local
            not local_client,  # headless
            vllm_config,
            vllm_config.parallel_config)

        if not needs_client_handshake:
            with primary as addresses:
                yield addresses
        else:
            assert local_client
            colocated = self._perform_handshake(ctx, client_handshake_address,
                                                identity, True, False,
                                                vllm_config)
            with primary as addresses, colocated as client_addresses:
                # Client I/O sockets come from the colocated front-end.
                addresses.inputs = client_addresses.inputs
                addresses.outputs = client_addresses.outputs
                yield addresses

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: Optional[ParallelConfig] = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Handshake with one front-end over a DEALER socket.

        On entry, sends HELLO and yields the addresses from the front-end's
        init message (see ``startup_handshake``). If the caller's block
        finishes without raising, a READY message is sent on the same socket
        before it is closed.
        """
        socket_cm = make_zmq_socket(ctx,
                                    handshake_address,
                                    zmq.DEALER,
                                    identity=identity,
                                    linger=5000,
                                    bind=False)
        with socket_cm as sock:
            # Register engine with front-end.
            yield self.startup_handshake(sock, local_client, headless,
                                         parallel_config_to_update)

            # Reached only when the caller's block did not raise.
            ready_message = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": vllm_config.cache_config.num_gpu_blocks,
                # The coordinator stats update address is passed back for
                # the external LB case, for our colocated front-end to use
                # (the coordinator only runs with rank 0).
                "dp_stats_address": self.frontend_stats_publish_address,
            }
            sock.send(msgspec.msgpack.encode(ready_message))

    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: Optional[ParallelConfig] = None,
    ) -> EngineZmqAddresses:
        """Send HELLO, wait for the front-end's init message, apply it.

        Args:
            handshake_socket: Connected handshake socket.
            local_client: Sent to the front-end as ``"local"``.
            headless: Sent to the front-end as ``"headless"``.
            parallel_config: If given, each entry of the init message's
                ``parallel_config`` mapping is set on it as an attribute.

        Returns:
            The ``addresses`` field of the decoded init message.

        Raises:
            RuntimeError: If no reply arrives within HANDSHAKE_TIMEOUT_MINS.
        """
        hello_message = {
            "status": "HELLO",
            "local": local_client,
            "headless": headless,
        }
        handshake_socket.send(msgspec.msgpack.encode(hello_message))

        logger.info("Waiting for init message from front-end.")
        timeout_ms = HANDSHAKE_TIMEOUT_MINS * 60_000
        if not handshake_socket.poll(timeout=timeout_ms):
            raise RuntimeError("Did not receive response from front-end "
                               f"process within {HANDSHAKE_TIMEOUT_MINS} "
                               "minutes")

        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            handshake_socket.recv(), type=EngineHandshakeMetadata)
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            overrides = init_message.parallel_config
            for attr_name, attr_value in overrides.items():
                setattr(parallel_config, attr_name, attr_value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args,
                        dp_rank: int = 0,
                        local_dp_rank: int = 0,
                        **kwargs):
        """Launch EngineCore busy loop in background process.

        Builds a ``DPEngineCoreProc`` when ``data_parallel_size > 1`` or
        ``dp_rank > 0``, otherwise an ``EngineCoreProc``, and runs its busy
        loop. ``args``/``kwargs`` are forwarded to the constructor; ``kwargs``
        must contain ``vllm_config``.

        SIGTERM/SIGINT raise ``SystemExit`` once. Other exceptions are logged
        and re-raised, after notifying clients via ``_send_engine_dead()``
        when the engine was constructed. The engine is always shut down on
        exit if it was constructed.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: Optional[EngineCoreProc] = None
        try:
            parallel_config: ParallelConfig = kwargs[
                "vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1 or dp_rank > 0:
                set_process_title("DPEngineCore", str(dp_rank))
                decorate_logs()
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                set_process_title("EngineCore")
                decorate_logs()
                engine_core = EngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            # Bare raise re-raises the original with its traceback intact.
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Data-parallel setup hook; does nothing in this class.

        Called from ``__init__`` inside the handshake context, before the
        base ``EngineCore`` is initialised. Presumably overridden by
        ``DPEngineCoreProc`` (defined outside this view) — confirm.
        """

    def run_busy_loop(self):
        """Core busy loop of the EngineCore.

        Never returns normally. It runs until an exception escapes,
        e.g. the ``SystemExit`` raised by the SIGINT/SIGTERM handler
        installed in ``run_engine_core``.
        """
        while True:
            # Block on client input until there is work to do...
            self._process_input_queue()
            # ...then run one engine step and publish its outputs.
            self._process_engine_step()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed.

        While no engines are running and the scheduler holds no requests,
        block on ``self.input_queue`` and dispatch each client request as it
        arrives. Once there is work to do, dispatch whatever is already
        queued without blocking, then return to the busy loop.
        """
        queue_in = self.input_queue
        logged_idle = False
        while not (self.engines_running or self.scheduler.has_requests()):
            # Only log the idle transition when it is actually observable,
            # i.e. when we are about to block on an empty queue.
            if logger.isEnabledFor(DEBUG) and queue_in.empty():
                logger.debug("EngineCore waiting for work.")
                logged_idle = True
            self._handle_client_request(*queue_in.get())

        if logged_idle:
            logger.debug("EngineCore loop active.")

        # Drain any further client requests that are already queued.
        while not queue_in.empty():
            self._handle_client_request(*queue_in.get_nowait())

    def _process_engine_step(self) -> bool:
        """Run one engine step and forward its outputs to the output thread.

        Called from the busy loop whenever it has work. In the data-parallel
        loop, that includes the case where engines are running but this rank
        has no local requests. Each ``(client_index, outputs)`` item returned
        by ``self.step_fn()`` is pushed onto ``self.output_queue``.

        Returns:
            Whether the model was actually executed during this step.
        """
        step_outputs, executed = self.step_fn()
        if step_outputs:
            for item in step_outputs.items():
                self.output_queue.put_nowait(item)
        return executed

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Dispatch request from client.

        Args:
            request_type: Kind of request carried in ``request``.
            request: Decoded payload whose shape depends on ``request_type``:
                a ``(request, request_wave)`` pair for ADD; a value passed
                straight to ``abort_requests`` for ABORT; a
                ``(client_idx, call_id, method_name, args)`` tuple for
                UTILITY.

        Raises:
            RuntimeError: When the client reports EXECUTOR_FAILED.
            BaseException: A non-``Exception`` interrupt (e.g. the
                ``SystemExit`` raised by the SIGINT/SIGTERM handler) that
                occurs inside a utility call is re-raised after the failure
                has been reported to the calling client.
        """
        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            interrupt: Optional[BaseException] = None
            try:
                method = getattr(self, method_name)
                result = method(*self._convert_msgspec_args(method, args))
                output.result = UtilityResult(result)
            except BaseException as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (f"Call to {method_name} method"
                                          f" failed: {str(e)}")
                if not isinstance(e, Exception):
                    # SystemExit / KeyboardInterrupt must not be swallowed.
                    # The signal handler only raises once, so absorbing it
                    # here would silently cancel the shutdown request.
                    interrupt = e
            # Always answer the caller so it is not left waiting, even when
            # we are about to propagate an interrupt.
            self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=output)))
            if interrupt is not None:
                raise interrupt
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error("Unrecognized input request type encountered: %s",
                         request_type)

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
         arg type, try converting to msgspec object."""
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
            msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation) else v
            for v, p in zip(args, arg_types))
760

761
762
763
764
765
766
767
768
769
770
771
772
    def _send_engine_dead(self):
        """Send EngineDead status to the EngineCoreClient.

        Enqueues the ``ENGINE_CORE_DEAD`` sentinel for the output IO thread.
        It then waits up to 5 seconds for that thread to finish, and logs a
        fatal message if it is still alive afterwards.
        """
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Presumably output_thread runs process_output_sockets, which breaks
        # out of its loop after forwarding the sentinel. A bounded join
        # therefore doubles as "wait until the message was sent".
        sender = self.output_thread
        sender.join(timeout=5.0)
        if not sender.is_alive():
            return
        logger.fatal("vLLM shutdown signal from EngineCore failed "
                     "to send. Please report this issue.")

    def process_input_sockets(self, input_addresses: list[str],
                              coord_input_address: Optional[str],
                              identity: bytes):
        """Input socket IO thread.

        Connects a DEALER socket to every front-end input address. When
        ``coord_input_address`` is given, it also connects an XSUB socket to
        the coordinator. It then loops forever: each received multipart
        message is decoded into ``(request_type, request)`` and pushed onto
        ``self.input_queue`` for the core busy loop.
        """
        # Msgpack decoders: ADD requests get a typed decoder, every other
        # request type is decoded generically.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:

            def _connect(address: str, socket_type):
                # Connect (not bind) a socket whose lifetime is owned by
                # the ExitStack.
                return stack.enter_context(
                    make_zmq_socket(ctx,
                                    address,
                                    socket_type,
                                    identity=identity,
                                    bind=False))

            input_sockets = [
                _connect(address, zmq.DEALER) for address in input_addresses
            ]

            coord_socket = None
            if coord_input_address is not None:
                coord_socket = _connect(coord_input_address, zmq.XSUB)
                # Send subscription message to coordinator.
                coord_socket.send(b'\x01')

            poller = zmq.Poller()
            for sock in input_sockets:
                # The front-end ROUTER socket can only send input messages
                # to us after it has received one from this socket, so send
                # an initial message before polling.
                sock.send(b'')
                poller.register(sock, zmq.POLLIN)
            if coord_socket is not None:
                poller.register(coord_socket, zmq.POLLIN)

            while True:
                for ready_socket, _ in poller.poll():
                    # Frames: (RequestType, *RequestData)
                    type_frame, *data_frames = ready_socket.recv_multipart(
                        copy=False)
                    request_type = EngineCoreRequestType(
                        bytes(type_frame.buffer))

                    if request_type == EngineCoreRequestType.ADD:
                        request = self.preprocess_add_request(
                            add_request_decoder.decode(data_frames))
                    else:
                        request = generic_decoder.decode(data_frames)

                    # Hand the request off to the core busy loop.
                    self.input_queue.put_nowait((request_type, request))

    def process_output_sockets(self, output_paths: list[str],
                               coord_output_path: Optional[str],
                               engine_index: int):
        """Output socket IO thread.

        Takes ``(client_index, outputs)`` pairs off ``self.output_queue``,
        stamps ``outputs.engine_index`` and sends them to that client's PUSH
        socket. Client index ``-1`` routes to the coordinator socket instead.
        When ``ENGINE_CORE_DEAD`` is dequeued, it is forwarded to every
        client socket and the thread returns.
        """
        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000))
                for output_path in output_paths
            ]
            coord_socket = stack.enter_context(
                make_zmq_socket(
                    ctx, coord_output_path, zmq.PUSH, bind=False,
                    linger=4000)) if coord_output_path is not None else None
            # Upper bound on idle buffers retained for reuse.
            max_reuse_bufs = len(sockets) + 1

            while True:
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    for socket in sockets:
                        socket.send(output)
                    break
                assert not isinstance(output, bytes)
                client_index, outputs = output
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Reclaim buffers that zmq is finished with (oldest entries
                # sit at the right end of ``pending``). Apply the same cap
                # as below: without it, a burst of in-flight sends would
                # leave an unbounded number of possibly large idle buffers
                # alive for the lifetime of this thread.
                while pending and pending[-1][0].done:
                    done_buffer = pending.pop()[2]
                    if len(reuse_buffers) < max_reuse_bufs:
                        reuse_buffers.append(done_buffer)

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = sockets[client_index].send_multipart(buffers,
                                                               copy=False,
                                                               track=True)
                if not tracker.done:
                    # zmq still references the frames. Keep the buffer alive,
                    # and for multi-frame (zero-copy) messages keep the
                    # outputs that own the extra frames alive too.
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
                elif len(reuse_buffers) < max_reuse_bufs:
                    # Limit the number of buffers to reuse.
                    reuse_buffers.append(buffer)


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: Optional[str] = None,
    ):
        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.counter = 0
        # Current DP "wave" this rank believes is active.
        self.current_wave = 0
        # Last request counts published by _maybe_publish_request_counts.
        self.last_counts = (0, 0)

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(vllm_config, local_client, handshake_address,
                         executor_class, log_stats, client_handshake_address,
                         dp_rank)

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Configure this rank's data-parallel state.

        Validates the DP rank/size, makes ``kv_transfer_config.engine_id``
        unique per local DP rank (when KV transfer is configured), records
        ``self.dp_rank`` and creates the stateless DP process group.
        """
        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        if vllm_config.kv_transfer_config is not None:
            # modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug("Setting kv_transfer_config.engine_id to %s",
                         vllm_config.kv_transfer_config.engine_id)

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        """Shut down the engine core, then destroy the DP group if present."""
        super().shutdown()
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: Request, request_wave: int = 0):
        """Add a request, keeping this rank's wave in sync with the client.

        Only when a coordinator is present and ``request_wave`` differs from
        ``self.current_wave``:
        - a newer wave is adopted;
        - an older wave while engines are idle triggers a ``start_wave``
          notification to the coordinator (client index -1).
        """
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave)))

        super().add_request(request, request_wave)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Handle START_DP_WAVE locally; defer everything else to the base.

        For START_DP_WAVE, ``request`` is ``(new_wave, exclude_eng_index)``.
        Unless this engine is the excluded one or the wave is stale, the
        wave is adopted and the idle loop is started.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                    new_wave >= self.current_wave):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.",
                                 new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Send scheduler request counts to the coordinator when they change.

        Does nothing unless ``self.publish_dp_lb_stats`` is set.
        """
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(*counts)
            self.output_queue.put_nowait(
                (-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs)

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug("Wave %d finished, pausing engine loop.",
                                 self.current_wave)
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (client_index,
                         EngineCoreOutputs(wave_complete=self.current_wave)))
                self.current_wave += 1

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank may still have unfinished requests.

        Only every 32nd call performs the DP-wide check; the calls in
        between optimistically return True.
        """
        # Optimization - only perform finish-sync all-reduce every 32 steps.
        self.counter += 1
        if self.counter != 32:
            return True
        self.counter = 0

        return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                local_unfinished)

    def reinitialize_distributed(
            self, reconfig_request: ReconfigureDistributedRequest) -> None:
        """Apply a DP reconfiguration (scale up/down) to this rank.

        Steps:
        - Tear down the current DP group and shut the engine down.
        - Update ``parallel_config`` from ``reconfig_request``.
        - Unless this rank is being shut down, build a new DP group.
        - Reconfigure the executor.
        - On scale-up, share the KV-cache memory budget with the new ranks
          and warm up the model.

        The resulting master port is written back into ``reconfig_request``.
        """
        # Destroy the current DP group, then drop our reference to it.
        # shutdown() also destroys ``self.dp_group``; without clearing it,
        # the same group would be destroyed again here. On the
        # SHUTDOWN_CURRENT_RANK path below, where no new group is created,
        # the final shutdown() would destroy it a third time.
        stateless_destroy_torch_distributed_process_group(self.dp_group)
        self.dp_group = None
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = \
            reconfig_request.new_data_parallel_size
        if reconfig_request.new_data_parallel_rank != \
                ReconfigureRankType.KEEP_CURRENT_RANK:
            parallel_config.data_parallel_rank = \
                reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert reconfig_request.new_data_parallel_rank_local == \
            ReconfigureRankType.KEEP_CURRENT_RANK
        parallel_config.data_parallel_master_ip = \
            reconfig_request.new_data_parallel_master_ip
        parallel_config.data_parallel_master_port = \
            reconfig_request.new_data_parallel_master_port
        if reconfig_request.new_data_parallel_rank != \
                ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        # Propagate the port actually in use (presumably
        # stateless_init_dp_group() may update it — confirm).
        reconfig_request.new_data_parallel_master_port = \
            parallel_config.data_parallel_master_port

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache)
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if reconfig_request.new_data_parallel_rank == \
                ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info("Distributed environment reinitialized for DP rank %s",
                        self.dp_rank)


class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Ray knows every address before actor creation; they are handed
        # back from _perform_handshakes instead of a real handshake.
        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_rank = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = \
            local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_cuda_visible_devices(vllm_config, local_dp_rank)

        # An empty handshake address is passed because _perform_handshakes
        # is overridden below and never uses it.
        super().__init__(vllm_config, local_client, "", executor_class,
                         log_stats)

    def _set_cuda_visible_devices(self, vllm_config: VllmConfig,
                                  local_dp_rank: int):
        """Restrict this actor to the devices of its local DP rank.

        Local DP rank ``r`` owns the ``world_size`` consecutive device ids
        ``[r * world_size, (r + 1) * world_size)``. Each is mapped through
        ``device_id_to_physical_device_id`` and the comma-joined result is
        written to the platform's device-control environment variable
        (CUDA_VISIBLE_DEVICES or equivalent).

        Raises:
            RuntimeError: If the mapping raises ``IndexError`` (presumably
                the range exceeds the devices currently visible — confirm).
                This was previously a bare ``Exception``; ``RuntimeError``
                is a subclass, so existing ``except Exception`` handlers
                still catch it.
        """
        from vllm.platforms import current_platform
        device_control_env_var = current_platform.device_control_env_var
        world_size = vllm_config.parallel_config.world_size
        start = local_dp_rank * world_size
        end = start + world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            os.environ[device_control_env_var] = ",".join(
                str(current_platform.device_id_to_physical_device_id(i))
                for i in range(start, end))
        except IndexError as e:
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{start}, {end}) "
                f"base value: \"{os.getenv(device_control_env_var)}\"") from e

    @contextmanager
    def _perform_handshakes(self, handshake_address: str, identity: bytes,
                            local_client: bool, vllm_config: VllmConfig,
                            client_handshake_address: Optional[str]):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.

        ``SystemExit`` is logged at debug level and re-raised; any other
        exception is logged as fatal and re-raised. The engine core is shut
        down in every case.
        """
        try:
            self.run_busy_loop()
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()