core.py 63.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
import queue
5
import signal
6
7
import threading
import time
8
from collections import deque
9
from collections.abc import Callable, Generator
10
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
11
from contextlib import ExitStack, contextmanager
12
from inspect import isclass, signature
13
from logging import DEBUG
14
from typing import Any, TypeVar, cast
15

16
import msgspec
17
18
import zmq

19
20
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
21
from vllm.envs import enable_envs_cache
22
from vllm.logger import init_logger
23
from vllm.logging_utils.dump_input import dump_engine_exception
24
from vllm.lora.request import LoRARequest
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.tasks import POOLING_TASKS, SupportedTask
27
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
28
29
30
31
from vllm.utils.gc_utils import (
    freeze_gc_heap,
    maybe_attach_gc_debug_callback,
)
32
from vllm.utils.hashing import get_hash_fn_by_name
33
from vllm.utils.network_utils import make_zmq_socket
34
from vllm.utils.system_utils import decorate_logs, set_process_title
35
36
37
38
39
40
41
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    generate_scheduler_kv_cache_config,
    get_kv_cache_configs,
    get_request_block_hasher,
    init_none_hash,
)
42
from vllm.v1.core.sched.interface import SchedulerInterface
43
from vllm.v1.core.sched.output import SchedulerOutput
44
from vllm.v1.engine import (
45
    EngineCoreOutput,
46
47
48
    EngineCoreOutputs,
    EngineCoreRequest,
    EngineCoreRequestType,
49
    FinishReason,
50
51
52
53
54
55
56
57
58
59
    ReconfigureDistributedRequest,
    ReconfigureRankType,
    UtilityOutput,
    UtilityResult,
)
from vllm.v1.engine.utils import (
    EngineHandshakeMetadata,
    EngineZmqAddresses,
    get_device_indices,
)
60
from vllm.v1.executor import Executor
61
from vllm.v1.kv_cache_interface import KVCacheConfig
62
from vllm.v1.metrics.stats import SchedulerStats
63
from vllm.v1.outputs import ModelRunnerOutput
64
from vllm.v1.request import Request, RequestStatus
65
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
66
from vllm.v1.structured_output import StructuredOutputManager
67
from vllm.v1.utils import compute_iteration_details
68
69
70
71
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# Polling timeout, presumably in seconds (per the `_S` suffix). It is not
# referenced by the code visible in this chunk.
# NOTE(review): confirm the consumer and units.
POLLING_TIMEOUT_S = 2.5
# Handshake timeout, presumably in minutes (per the `_MINS` suffix).
# NOTE(review): confirm where the handshake consumes it.
HANDSHAKE_TIMEOUT_MINS = 5

_R = TypeVar("_R")  # Return type for collective_rpc


class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor and the scheduler and drives the
    schedule -> execute -> update cycle, either through `step` or, when the
    executor reports more than one concurrent batch, through
    `step_with_batch_queue` (selected in `self.step_fn`).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        executor_fail_callback: Callable | None = None,
        include_finished_set: bool = False,
    ):
        """Create the executor, size the KV cache and build the scheduler.

        Args:
            vllm_config: Full engine configuration. It is mutated here:
                `cache_config.num_gpu_blocks` / `num_cpu_blocks` are filled
                in, and chunked prefill is disabled for models with no KV
                cache groups.
            executor_class: Executor type, instantiated with `vllm_config`.
            log_stats: Stored and forwarded to the scheduler.
            executor_fail_callback: Optional callback registered on the
                executor via `register_failure_callback`.
            include_finished_set: Forwarded to the scheduler.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        self.vllm_config = vllm_config
        if not vllm_config.parallel_config.data_parallel_rank_local:
            logger.info(
                "Initializing a V1 LLM engine (v%s) with config: %s",
                VLLM_VERSION,
                vllm_config,
            )

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(executor_fail_callback)

        # Overwritten by _initialize_kv_caches when the model has a KV cache.
        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
            vllm_config
        )

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler.
        Scheduler = vllm_config.scheduler_config.get_scheduler_cls()

        if len(kv_cache_config.kv_cache_groups) == 0:  # noqa: SIM102
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            if vllm_config.scheduler_config.enable_chunked_prefill:
                logger.warning("Disabling chunked prefill for model without KVCache")
                vllm_config.scheduler_config.enable_chunked_prefill = False

        # The scheduler's block size is the cache block size multiplied by
        # the decode and prefill context-parallel sizes.
        scheduler_block_size = (
            vllm_config.cache_config.block_size
            * vllm_config.parallel_config.decode_context_parallel_size
            * vllm_config.parallel_config.prefill_context_parallel_size
        )

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=include_finished_set,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )
        self.use_spec_decode = vllm_config.speculative_config is not None
        if self.scheduler.connector is not None:  # type: ignore
            self.model_executor.init_kv_output_aggregator(self.scheduler.connector)  # type: ignore

        self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
        self.mm_receiver_cache = mm_registry.engine_receiver_cache_from_config(
            vllm_config
        )

        # If a KV connector is initialized for scheduler, we want to collect
        # handshake metadata from all workers so the connector in the scheduler
        # will have the full context
        kv_connector = self.scheduler.get_kv_connector()
        if kv_connector is not None:
            # Collect and store KV connector xfer metadata from workers
            # (after KV cache registration)
            xfer_handshake_metadata = (
                self.model_executor.get_kv_connector_handshake_metadata()
            )

            if xfer_handshake_metadata:
                # xfer_handshake_metadata is list of dicts from workers
                # Each dict already has structure {tp_rank: metadata}
                # Merge all worker dicts into a single dict
                content: dict[int, Any] = {}
                for worker_dict in xfer_handshake_metadata:
                    if worker_dict is not None:
                        content.update(worker_dict)
                kv_connector.set_xfer_handshake_metadata(content)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: (
            deque[tuple[Future[ModelRunnerOutput], SchedulerOutput, Future[Any]]] | None
        ) = None
        if self.batch_queue_size > 1:
            logger.debug("Batch queue is enabled with size %d", self.batch_queue_size)
            self.batch_queue = deque(maxlen=self.batch_queue_size)

        self.is_ec_producer = (
            vllm_config.ec_transfer_config is not None
            and vllm_config.ec_transfer_config.is_ec_producer
        )
        self.is_pooling_model = vllm_config.model_config.runner_type == "pooling"

        # Block hashes are needed both for prefix caching and by KV connectors.
        self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
        if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo
            )
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                scheduler_block_size, caching_hash_fn
            )

        self.step_fn = (
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )
        self.async_scheduling = vllm_config.scheduler_config.async_scheduling

        # Request ids to abort; drained by _process_aborts_queue().
        self.aborts_queue = queue.Queue[list[str]]()

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()
        # If enable, attach GC debugger after static variable freeze.
        maybe_attach_gc_debug_callback()
        # Enable environment variable cache (e.g. assume no more
        # environment variable overrides after this point)
        enable_envs_cache()

    def _initialize_kv_caches(
        self, vllm_config: VllmConfig
    ) -> tuple[int, int, KVCacheConfig]:
        """Determine KV cache memory, build the cache configs and warm up.

        Returns:
            `(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)`.
            `num_cpu_blocks` is always 0 here.
        """
        # This measures a duration, so use a monotonic clock that is immune
        # to wall-clock adjustments.
        start = time.monotonic()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # Elastic-EP scale-up launch: skip profiling and obtain the
                # KV cache memory size through the DP group. EngineCore does
                # not set `dp_group` itself; it must already exist here.
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = (
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                )
                available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                    kv_cache_specs
                )
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = self.model_executor.determine_available_memory()
                self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)

        # Track max_model_len before KV cache config to detect auto-fit changes
        max_model_len_before = vllm_config.model_config.max_model_len

        kv_cache_configs = get_kv_cache_configs(
            vllm_config, kv_cache_specs, available_gpu_memory
        )

        # If auto-fit reduced max_model_len, sync the new value to workers.
        # This is needed because workers were spawned before memory profiling
        # and have the original (larger) max_model_len cached.
        max_model_len_after = vllm_config.model_config.max_model_len
        if max_model_len_after != max_model_len_before:
            self.collective_rpc("update_max_model_len", args=(max_model_len_after,))

        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
        num_gpu_blocks = scheduler_kv_cache_config.num_blocks
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.monotonic() - start
        logger.info_once(
            "init engine (profile, create kv cache, warmup model) took %.2f seconds",
            elapsed,
            scope="local",
        )
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the executor reports as supported."""
        return self.model_executor.supported_tasks

    def add_request(self, request: Request, request_wave: int = 0):
        """Add request to the scheduler.

        `request_wave`: indicate which wave of requests this is expected to
        belong to in DP case

        Raises:
            TypeError: if `request.request_id` is not a `str`.
            ValueError: if a pooling request asks for an unsupported task.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}"
            )

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks() if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {pooling_params.task!r} "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

        if request.kv_transfer_params is not None and (
            not self.scheduler.get_kv_connector()
        ):
            logger.warning(
                "Got kv_transfer_params, but no KVConnector found. "
                "Disabling KVTransfer for this request."
            )

        self.scheduler.add_request(request)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)

    @contextmanager
    def log_error_detail(self, scheduler_output: SchedulerOutput):
        """Execute the model and log detailed info on failure."""
        try:
            yield
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: This method is exception-free
            dump_engine_exception(
                self.vllm_config, scheduler_output, self.scheduler.make_stats()
            )
            # Bare raise re-raises the original exception with its traceback.
            raise

    @contextmanager
    def log_iteration_details(self, scheduler_output: SchedulerOutput):
        """Log per-iteration request/token counts and elapsed time.

        This is a no-op unless
        `observability_config.enable_logging_iteration_details` is set. If
        the wrapped body raises, nothing is logged and the iteration index
        is not advanced.
        """
        if not self.vllm_config.observability_config.enable_logging_iteration_details:
            yield
            return
        # Lazily initialized counter of logged iterations.
        self._iteration_index = getattr(self, "_iteration_index", 0)
        iteration_details = compute_iteration_details(scheduler_output)
        before = time.monotonic()
        yield
        logger.info(
            "Iteration(%s): %s context requests, %s context tokens, "
            "%s generation requests, %s generation tokens, "
            "iteration elapsed time: %.2f ms",
            self._iteration_index,
            iteration_details.num_ctx_requests,
            iteration_details.num_ctx_tokens,
            iteration_details.num_generation_requests,
            iteration_details.num_generation_tokens,
            (time.monotonic() - before) * 1000,
        )
        self._iteration_index += 1

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        future = self.model_executor.execute_model(scheduler_output, non_block=True)
        grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
        with (
            self.log_error_detail(scheduler_output),
            self.log_iteration_details(scheduler_output),
        ):
            model_output = future.result()
            # A None result means sampling is a separate call.
            if model_output is None:
                model_output = self.model_executor.sample_tokens(grammar_output)

        # Before processing the model output, process any aborts that happened
        # during the model execution.
        self._process_aborts_queue()
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

    def post_step(self, model_executed: bool) -> None:
        """Feed draft token ids back to the scheduler after a step.

        Only done for synchronous scheduling with speculative decoding, and
        only when the model was executed in the step.
        """
        # When using async scheduling we can't get draft token ids in advance,
        # so we update draft token ids in the worker process and don't
        # need to update draft token ids here.
        if not self.async_scheduling and self.use_spec_decode and model_executed:
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:
                self.scheduler.update_draft_token_ids(draft_token_ids)

    def step_with_batch_queue(
        self,
    ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, fulfilling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        batch_queue = self.batch_queue
        assert batch_queue is not None

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        assert len(batch_queue) < self.batch_queue_size

        model_executed = False
        # Set when sampling for the newly scheduled batch must wait until the
        # previous batch's output has been processed (structured output).
        deferred_scheduler_output: SchedulerOutput | None = None
        if self.scheduler.has_requests():
            scheduler_output = self.scheduler.schedule()
            exec_future = self.model_executor.execute_model(
                scheduler_output, non_block=True
            )
            if not self.is_ec_producer:
                model_executed = scheduler_output.total_num_scheduled_tokens > 0

            if self.is_pooling_model or not model_executed:
                # No sampling required (no requests scheduled).
                future = cast(Future[ModelRunnerOutput], exec_future)
            else:
                if not scheduler_output.pending_structured_output_tokens:
                    # We aren't waiting for any tokens, get any grammar output
                    # and sample immediately.
                    grammar_output = self.scheduler.get_grammar_bitmask(
                        scheduler_output
                    )
                    future = self.model_executor.sample_tokens(
                        grammar_output, non_block=True
                    )
                else:
                    # We need to defer sampling until we have processed the model output
                    # from the prior step.
                    deferred_scheduler_output = scheduler_output

            if deferred_scheduler_output is None:
                # Add this step's future to the queue.
                batch_queue.appendleft((future, scheduler_output, exec_future))
                if (
                    model_executed
                    and len(batch_queue) < self.batch_queue_size
                    and not batch_queue[-1][0].done()
                ):
                    # Don't block on next worker response unless the queue is full
                    # or there are no more requests to schedule.
                    return None, True

        elif not batch_queue:
            # Queue is empty. We should not reach here since this method should
            # only be called when the scheduler contains requests or the queue
            # is non-empty.
            return None, False

        # Block until the next result is available.
        future, scheduler_output, exec_model_fut = batch_queue.pop()
        with (
            self.log_error_detail(scheduler_output),
            self.log_iteration_details(scheduler_output),
        ):
            model_output = future.result()
            if model_output is None:
                # None from sample_tokens() implies that the original execute_model()
                # call failed - raise that exception.
                exec_model_fut.result()
                raise RuntimeError("unexpected error")

        # Before processing the model output, process any aborts that happened
        # during the model execution.
        self._process_aborts_queue()
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        # NOTE(nick): We can either handle the deferred tasks here or save
        # in a field and do it immediately once step_with_batch_queue is
        # re-called. The latter slightly favors TTFT over TPOT/throughput.
        if deferred_scheduler_output is not None:
            # If we are doing speculative decoding with structured output,
            # we need to get the draft token ids from the prior step before
            # we can compute the grammar bitmask for the deferred request.
            if self.use_spec_decode:
                draft_token_ids = self.model_executor.take_draft_token_ids()
                assert draft_token_ids is not None
                # Update the draft token ids in the scheduler output to
                # filter out the invalid spec tokens, which will be padded
                # with -1 and skipped by the grammar bitmask computation.
                self.scheduler.update_draft_token_ids_in_output(
                    draft_token_ids, deferred_scheduler_output
                )
            # We now have the tokens needed to compute the bitmask for the
            # deferred request. Get the bitmask and call sample tokens.
            grammar_output = self.scheduler.get_grammar_bitmask(
                deferred_scheduler_output
            )
            future = self.model_executor.sample_tokens(grammar_output, non_block=True)
            # exec_future is the execute_model() future of the batch scheduled
            # in this call (the deferred one), not the one just popped.
            batch_queue.appendleft((future, deferred_scheduler_output, exec_future))

        return engine_core_outputs, model_executed

    def _process_aborts_queue(self):
        """Drain `aborts_queue` and abort all collected request ids at once."""
        if not self.aborts_queue.empty():
            request_ids = []
            while not self.aborts_queue.empty():
                ids = self.aborts_queue.get_nowait()
                # Should be a list here, but also handle string just in case.
                request_ids.extend((ids,) if isinstance(ids, str) else ids)
            # More efficient to abort all as a single batch.
            self.abort_requests(request_ids)

    def shutdown(self):
        """Clear the structured-output backend, then shut down the executor
        and the scheduler, in that order.

        Every step is attempted even if an earlier one raises, so a failure
        in one component does not skip the others' teardown. If several
        steps raise, the last exception propagates with the earlier ones
        attached as its `__context__`.
        """
        with ExitStack() as stack:
            # ExitStack runs callbacks LIFO: register the scheduler first so
            # it is shut down after the executor, matching the required order.
            if self.scheduler:
                stack.callback(self.scheduler.shutdown)
            if self.model_executor:
                stack.callback(self.model_executor.shutdown)
            self.structured_output_manager.clear_backend()

    def profile(self, is_start: bool = True):
        """Start (default) or stop profiling on the executor."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Clear the multi-modal caches held by the engine and the executor."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver)
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the multi-modal cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # The cache either exists in EngineCore or WorkerWrapperBase
        if self.mm_receiver_cache is not None:
            self.mm_receiver_cache.clear_cache()

        self.model_executor.reset_mm_cache()

    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Delegate to the scheduler's `reset_prefix_cache` and return its result."""
        return self.scheduler.reset_prefix_cache(
            reset_running_requests, reset_connector
        )

    def sleep(self, level: int = 1):
        """Put the executor to sleep at the given level."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: list[str] | None = None):
        """Wake the executor up, optionally restricted to `tags`."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's `is_sleeping` flag."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run a dummy batch on the executor."""
        self.model_executor.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter via the executor."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter via the executor."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter via the executor."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Forward to the executor's `save_sharded_state`."""
        self.model_executor.save_sharded_state(
            path=path, pattern=pattern, max_size=max_size
        )

    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Forward to the executor's `collective_rpc` and return its result
        list (presumably one entry per worker)."""
        return self.model_executor.collective_rpc(method, timeout, args, kwargs)

    def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Preprocess the request.

        This function could be directly used in input processing thread to allow
        request initialization running in parallel with Model forward

        Returns:
            The constructed `Request` and the request's `current_wave`.
        """
        # Note on thread safety: no race condition.
        # `mm_receiver_cache` is reset at the end of LLMEngine init,
        # and will only be accessed in the input processing thread afterwards.
        if self.mm_receiver_cache is not None and request.mm_features:
            request.mm_features = self.mm_receiver_cache.get_and_update_features(
                request.mm_features
            )

        req = Request.from_engine_core_request(request, self.request_block_hasher)
        if req.use_structured_output:
            # Note on thread safety: no race condition.
            # `grammar_init` is only invoked in input processing thread. For
            # `structured_output_manager`, each request is independent and
            # grammar compilation is async. Scheduler always checks grammar
            # compilation status before scheduling request.
            self.structured_output_manager.grammar_init(req)
        return req, request.current_wave


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        *,
        engine_index: int = 0,
    ):
        """Perform the startup handshakes, build the engine and start IO threads.

        The engine itself (``EngineCore.__init__``) is constructed inside the
        handshake context, so the READY message is only sent back to the
        front-end after the engine and both socket threads are up.
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]()
        # Executor failures are surfaced to the busy loop as a special request
        # type, which `_handle_client_request` turns into a RuntimeError.
        executor_fail_callback = lambda: self.input_queue.put_nowait(
            (EngineCoreRequestType.EXECUTOR_FAILED, b"")
        )

        self.engine_index = engine_index
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        with self._perform_handshakes(
            handshake_address,
            identity,
            local_client,
            vllm_config,
            client_handshake_address,
        ) as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address
            )
            logger.debug(
                "Has DP Coordinator: %s, stats publish address: %s",
                self.has_coordinator,
                self.frontend_stats_publish_address,
            )
            internal_dp_balancing = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb
            )
            # Only publish request queue stats to coordinator for "internal"
            # and "hybrid" LB modes.
            self.publish_dp_lb_stats = internal_dp_balancing

            self._init_data_parallel(vllm_config)

            super().__init__(
                vllm_config,
                executor_class,
                log_stats,
                executor_fail_callback,
                internal_dp_balancing,
            )

            # Background Threads and Queues for IO. These enable us to
            # overlap ZMQ socket IO with GPU since they release the GIL,
            # and to overlap some serialization/deserialization with the
            # model forward pass.
            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
            ready_event = threading.Event()
            input_thread = threading.Thread(
                target=self.process_input_sockets,
                args=(
                    addresses.inputs,
                    addresses.coordinator_input,
                    identity,
                    ready_event,
                ),
                daemon=True,
            )
            input_thread.start()

            self.output_thread = threading.Thread(
                target=self.process_output_sockets,
                args=(
                    addresses.outputs,
                    addresses.coordinator_output,
                    self.engine_index,
                ),
                daemon=True,
            )
            self.output_thread.start()

            # Don't complete handshake until DP coordinator ready message is
            # received.
            while not ready_event.wait(timeout=10):
                if not input_thread.is_alive():
                    raise RuntimeError("Input socket thread died during startup")
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Perform startup handshakes.

        For DP=1 or offline mode, this is with the colocated front-end process.

        For DP>1 with internal load-balancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external or hybrid load-balancing, two handshakes are
        performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
        with the exception of the rank 0 and colocated engines themselves which
        don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
        server is not scaled out), OR the launcher process running the
        run_multi_api_server() function in serve.py.
        """
        input_ctx = zmq.Context()
        is_local = local_client and client_handshake_address is None
        headless = not local_client
        handshake = self._perform_handshake(
            input_ctx,
            handshake_address,
            identity,
            is_local,
            headless,
            vllm_config,
            vllm_config.parallel_config,
        )
        if client_handshake_address is None:
            with handshake as addresses:
                yield addresses
        else:
            assert local_client
            local_handshake = self._perform_handshake(
                input_ctx, client_handshake_address, identity, True, False, vllm_config
            )
            # Coordinator/DP addresses come from the primary handshake; the
            # client-facing input/output sockets come from the colocated one.
            with handshake as addresses, local_handshake as client_addresses:
                addresses.inputs = client_addresses.inputs
                addresses.outputs = client_addresses.outputs
                yield addresses

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: ParallelConfig | None = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run one HELLO/READY handshake over a DEALER socket.

        Yields the addresses received from the front-end. The READY message
        (with ``num_gpu_blocks`` and, for DP>1, the parallel config hash) is
        only sent once the ``with`` body completes without raising.
        """
        with make_zmq_socket(
            ctx,
            handshake_address,
            zmq.DEALER,
            identity=identity,
            linger=5000,
            bind=False,
        ) as handshake_socket:
            # Register engine with front-end.
            addresses = self.startup_handshake(
                handshake_socket, local_client, headless, parallel_config_to_update
            )
            yield addresses

            # Send ready message.
            num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
            # We pass back the coordinator stats update address here for the
            # external LB case for our colocated front-end to use (coordinator
            # only runs with rank 0).
            dp_stats_address = self.frontend_stats_publish_address

            # Include config hash for DP configuration validation
            ready_msg = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": num_gpu_blocks,
                "dp_stats_address": dp_stats_address,
            }
            if vllm_config.parallel_config.data_parallel_size > 1:
                ready_msg["parallel_config_hash"] = (
                    vllm_config.parallel_config.compute_hash()
                )

            handshake_socket.send(msgspec.msgpack.encode(ready_msg))

    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: ParallelConfig | None = None,
    ) -> EngineZmqAddresses:
        """Send HELLO and wait for the front-end's init message.

        Waits up to ``HANDSHAKE_TIMEOUT_MINS`` minutes. Any ``parallel_config``
        entries in the init message are applied to ``parallel_config`` when
        one is given.

        Raises:
            RuntimeError: if no response arrives within the timeout.
        """
        # Send registration message.
        handshake_socket.send(
            msgspec.msgpack.encode(
                {
                    "status": "HELLO",
                    "local": local_client,
                    "headless": headless,
                }
            )
        )

        # Receive initialization message.
        logger.debug("Waiting for init message from front-end.")
        if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
            raise RuntimeError(
                "Did not receive response from front-end "
                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                f"minutes"
            )
        init_bytes = handshake_socket.recv()
        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            init_bytes, type=EngineHandshakeMetadata
        )
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for key, value in init_message.parallel_config.items():
                setattr(parallel_config, key, value)

        return init_message.addresses

    @staticmethod
    def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
        """Launch EngineCore busy loop in background process."""

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: EngineCoreProc | None = None
        try:
            vllm_config: VllmConfig = kwargs["vllm_config"]
            parallel_config: ParallelConfig = vllm_config.parallel_config
            data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0
            if data_parallel:
                parallel_config.data_parallel_rank_local = local_dp_rank
                set_process_title("EngineCore", f"DP{dp_rank}")
            else:
                set_process_title("EngineCore")
            decorate_logs()

            if data_parallel and vllm_config.kv_transfer_config is not None:
                # modify the engine_id and append the local_dp_rank to it to ensure
                # that the kv_transfer_config is unique for each DP rank.
                vllm_config.kv_transfer_config.engine_id = (
                    f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
                )
                logger.debug(
                    "Setting kv_transfer_config.engine_id to %s",
                    vllm_config.kv_transfer_config.engine_id,
                )

            parallel_config.data_parallel_index = dp_rank
            if data_parallel and vllm_config.model_config.is_moe:
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                # Non-MoE DP ranks are completely independent, so treat like DP=1.
                # Note that parallel_config.data_parallel_index will still reflect
                # the original DP rank.
                parallel_config.data_parallel_size = 1
                parallel_config.data_parallel_size_local = 1
                parallel_config.data_parallel_rank = 0
                engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            # Bare `raise` keeps the original traceback intact.
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Hook called from ``__init__`` before the engine is built.

        No-op here; ``DPEngineCoreProc`` overrides it to set up the DP group.
        """
        pass

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()
            # 2) Step the engine core and return the outputs.
            self._process_engine_step()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed."""

        waited = False
        while (
            not self.engines_running
            and not self.scheduler.has_requests()
            and not self.batch_queue
        ):
            if self.input_queue.empty():
                # Drain aborts queue; all aborts are also processed via input_queue.
                with self.aborts_queue.mutex:
                    self.aborts_queue.queue.clear()
                if logger.isEnabledFor(DEBUG):
                    logger.debug("EngineCore waiting for work.")
                    waited = True
            # Blocks until the next client request arrives.
            req = self.input_queue.get()
            self._handle_client_request(*req)

        if waited:
            logger.debug("EngineCore loop active.")

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

    def _process_engine_step(self) -> bool:
        """Called only when there are unfinished local requests."""

        # Step the engine core.
        outputs, model_executed = self.step_fn()
        # Put EngineCoreOutputs into the output queue.
        for output in outputs.items() if outputs else ():
            self.output_queue.put_nowait(output)
        # Post-step hook.
        self.post_step(model_executed)

        # If no model execution happened but there are waiting requests
        # (e.g., WAITING_FOR_REMOTE_KVS), yield the GIL briefly to allow
        # background threads (like NIXL handshake) to make progress.
        # Without this, the tight polling loop can starve background threads.
        if not model_executed and self.scheduler.has_unfinished_requests():
            time.sleep(0.001)

        return model_executed

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Dispatch request from client.

        Raises:
            RuntimeError: on an ``EXECUTOR_FAILED`` request.
            BaseException: a non-``Exception`` (e.g. ``SystemExit`` from the
                shutdown signal handler) raised by a utility method is
                propagated after its failure reply has been queued.
        """

        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                result = method(*self._convert_msgspec_args(method, args))
                output.result = UtilityResult(result)
            except BaseException as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (
                    f"Call to {method_name} method failed: {str(e)}"
                )
                # The signal handler in run_engine_core raises SystemExit only
                # once; swallowing it here would make the engine ignore the
                # termination request. Re-raise non-Exception errors (the
                # reply is still queued by the `finally` below).
                if not isinstance(e, Exception):
                    raise
            finally:
                self.output_queue.put_nowait(
                    (client_idx, EngineCoreOutputs(utility_output=output))
                )
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error(
                "Unrecognized input request type encountered: %s", request_type
            )

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
        arg type, try converting to msgspec object."""
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
            msgspec.convert(v, type=p.annotation)
            if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation)
            else v
            for v, p in zip(args, arg_types)
        )

    def _send_engine_dead(self):
        """Send EngineDead status to the EngineCoreClient."""

        # Put ENGINE_CORE_DEAD in the queue.
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Wait until msg sent by the daemon before shutdown.
        self.output_thread.join(timeout=5.0)
        if self.output_thread.is_alive():
            logger.fatal(
                "vLLM shutdown signal from EngineCore failed "
                "to send. Please report this issue."
            )

    def process_input_sockets(
        self,
        input_addresses: list[str],
        coord_input_address: str | None,
        identity: bytes,
        ready_event: threading.Event,
    ):
        """Input socket IO thread.

        Decodes requests from the front-end (and optional DP coordinator)
        sockets and pushes ``(request_type, request)`` onto ``input_queue``.
        ``ready_event`` is set once all sockets are connected (and, with a
        coordinator, once its READY message has been received).
        """

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(
                        ctx, input_address, zmq.DEALER, identity=identity, bind=False
                    )
                )
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx,
                        coord_input_address,
                        zmq.XSUB,
                        identity=identity,
                        bind=False,
                    )
                )
                # Send subscription message to coordinator.
                coord_socket.send(b"\x01")

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b"")
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator. The recv must not
                # live inside an `assert`: under `python -O` it would be
                # skipped and READY would later be misread as a request type.
                coord_ready_msg = coord_socket.recv()
                if coord_ready_msg != b"READY":
                    raise RuntimeError(
                        "Expected READY message from DP Coordinator, "
                        f"got {coord_ready_msg!r}"
                    )
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            del ready_event
            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(copy=False)
                    request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                    # Deserialize the request data.
                    request: Any
                    if request_type == EngineCoreRequestType.ADD:
                        req: EngineCoreRequest = add_request_decoder.decode(data_frames)
                        try:
                            request = self.preprocess_add_request(req)
                        except Exception:
                            self._handle_request_preproc_error(req)
                            continue
                    else:
                        request = generic_decoder.decode(data_frames)

                        if request_type == EngineCoreRequestType.ABORT:
                            # Aborts are added to *both* queues, allows us to eagerly
                            # process aborts while also ensuring ordering in the input
                            # queue to avoid leaking requests. This is ok because
                            # aborting in the scheduler is idempotent.
                            self.aborts_queue.put_nowait(request)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

    def process_output_sockets(
        self,
        output_paths: list[str],
        coord_output_path: str | None,
        engine_index: int,
    ):
        """Output socket IO thread.

        Drains ``output_queue`` and sends each item to the client socket at
        its ``client_index`` (index ``-1`` goes to the coordinator socket).
        Exits after broadcasting ``ENGINE_CORE_DEAD`` to all client sockets.
        """

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)
                )
                for output_path in output_paths
            ]
            coord_socket = (
                stack.enter_context(
                    make_zmq_socket(
                        ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000
                    )
                )
                if coord_output_path is not None
                else None
            )
            max_reuse_bufs = len(sockets) + 1

            while True:
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    for socket in sockets:
                        socket.send(output)
                    break
                assert not isinstance(output, bytes)
                client_index, outputs = output
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Reclaim buffers that zmq is finished with.
                while pending and pending[-1][0].done:
                    reuse_buffers.append(pending.pop()[2])

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = sockets[client_index].send_multipart(
                    buffers, copy=False, track=True
                )
                if not tracker.done:
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
                elif len(reuse_buffers) < max_reuse_bufs:
                    # Limit the number of buffers to reuse.
                    reuse_buffers.append(buffer)

    def _handle_request_preproc_error(self, request: EngineCoreRequest) -> None:
        """Log and return a request-scoped error response for exceptions raised
        from the add request preprocessing in the input socket processing thread.
        """
        logger.exception(
            "Unexpected error pre-processing request %s", request.request_id
        )
        self.output_queue.put_nowait(
            (
                request.client_index,
                EngineCoreOutputs(
                    engine_index=self.engine_index,
                    finished_requests={request.request_id},
                    outputs=[
                        EngineCoreOutput(
                            request_id=request.request_id,
                            new_token_ids=[],
                            finish_reason=FinishReason.ERROR,
                        )
                    ],
                ),
            )
        )


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    # Number of engine steps between finish-sync all-reduces with DP peers.
    # On the steps in between, _has_global_unfinished_reqs() skips the
    # collective and assumes peers are still running.
    _FINISH_SYNC_INTERVAL = 32

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
    ):
        assert vllm_config.model_config.is_moe, (
            "DPEngineCoreProc should only be used for MoE models"
        )

        # Wave/step bookkeeping is set before super().__init__(), presumably
        # because base-class initialization may already invoke overridden
        # methods that read it. NOTE(review): confirm.
        # step_counter counts forward-passes of the model so that we can
        # synchronize finished with DP peers every _FINISH_SYNC_INTERVAL steps.
        self.step_counter = 0
        self.current_wave = 0
        # Request counts last published by _maybe_publish_request_counts().
        self.last_counts = (0, 0)

        # Initialize the engine; the DP rank is used as the engine index.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(
            vllm_config,
            local_client,
            handshake_address,
            executor_class,
            log_stats,
            client_handshake_address,
            engine_index=dp_rank,
        )

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Validate the DP rank layout and create the stateless DP group.

        Sets ``self.dp_rank`` and ``self.dp_group``.
        """
        # Configure GPUs and stateless process group for data parallel.
        parallel_config = vllm_config.parallel_config
        dp_rank = parallel_config.data_parallel_rank
        dp_size = parallel_config.data_parallel_size
        local_dp_rank = parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert local_dp_rank is not None
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        self.dp_rank = dp_rank
        self.dp_group = parallel_config.stateless_init_dp_group()

    def shutdown(self):
        """Shut down the engine, then destroy the DP process group.

        The DP group is destroyed at most once: the attribute is cleared
        before destruction, so a repeated ``shutdown()`` (or one following
        an explicit teardown in ``reinitialize_distributed``) skips it.
        ``dp_group`` may be absent if initialization failed before
        ``_init_data_parallel`` ran, hence the ``getattr``.
        """
        super().shutdown()
        if dp_group := getattr(self, "dp_group", None):
            self.dp_group = None
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: Request, request_wave: int = 0):
        """Reconcile the request's DP wave with ours, then add the request.

        This only applies when a coordinator is present:

        - A request from a newer wave advances ``current_wave``.
        - A request from an already-completed wave, received while the engines
          are paused, queues a ``start_wave`` notification on client index -1.

        The request itself is always passed to the base implementation.
        """
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave))
                )

        super().add_request(request, request_wave)

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Handle ``START_DP_WAVE`` locally; delegate other types to the base.

        For ``START_DP_WAVE``, ``request`` is unpacked as
        ``(new_wave, exclude_eng_index)``. Unless this engine is the excluded
        one, the wave is adopted when it is not older than ours. A paused
        loop is then marked running.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                new_wave >= self.current_wave
            ):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.", new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Queue scheduler request counts on client index -1 when they change.

        No-op unless ``publish_dp_lb_stats`` is set.
        """
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(
                *counts, step_counter=self.step_counter, current_wave=self.current_wave
            )
            self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs
            )

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug(
                        "Wave %d finished, pausing engine loop.", self.current_wave
                    )
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (
                            client_index,
                            EngineCoreOutputs(wave_complete=self.current_wave),
                        )
                    )
                # Increment wave count and reset step counter.
                self.current_wave += 1
                self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank may still have unfinished requests.

        The collective is only run every ``_FINISH_SYNC_INTERVAL`` steps.
        On all other steps this optimistically returns ``True``.
        """
        # Optimization - only perform the finish-sync all-reduce periodically.
        self.step_counter += 1
        if self.step_counter % self._FINISH_SYNC_INTERVAL != 0:
            return True

        return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Rebuild DP state for a change in data-parallel size.

        Steps:

        1. Destroy the current DP group and shut down.
        2. Apply the requested size, rank, master IP and port to the parallel
           config.
        3. Create a new DP group unless this rank is being shut down.
        4. Reinitialize the executor.

        When the DP size grows, KV-cache memory size is synced to new ranks
        and the model is warmed up. When this rank is the one being removed,
        it shuts down at the end.
        """
        # Tear down the current DP group, clearing the attribute first so that
        # shutdown() (here and below) does not destroy the same group again.
        old_dp_group, self.dp_group = self.dp_group, None
        stateless_destroy_torch_distributed_process_group(old_dp_group)
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        # NOTE(review): -1 presumably means "keep current rank"
        # (cf. ReconfigureRankType.KEEP_CURRENT_RANK) — confirm.
        if reconfig_request.new_data_parallel_rank != -1:
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert (
            reconfig_request.new_data_parallel_rank_local
            == ReconfigureRankType.KEEP_CURRENT_RANK
        )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        # NOTE(review): -2 presumably means "shut down this rank"
        # (cf. ReconfigureRankType.SHUTDOWN_CURRENT_RANK); such a rank gets no
        # new DP group — confirm.
        if reconfig_request.new_data_parallel_rank != -2:
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        # Echo the config's master port back into the request before it is
        # forwarded to the executor (group init may have updated it — confirm).
        reconfig_request.new_data_parallel_master_port = (
            parallel_config.data_parallel_master_port
        )

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache
            )
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info(
                "Distributed environment reinitialized for DP rank %s", self.dp_rank
            )


class EngineCoreActorMixin:
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        addresses: EngineZmqAddresses,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        """Record ZMQ addresses, stamp DP indices on the config, set devices.

        Args:
            vllm_config: Engine config; its ``parallel_config`` is mutated.
            addresses: Pre-computed ZMQ addresses, later yielded by
                ``_perform_handshakes`` instead of doing a real handshake.
            dp_rank: Written to ``parallel_config.data_parallel_index``.
            local_dp_rank: Written to
                ``parallel_config.data_parallel_rank_local``. It is also used
                to select this actor's visible devices.
        """
        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_index = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_visible_devices(vllm_config, local_dp_rank)

    def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int):
        """Restrict device visibility for this DP rank; no-op on XPU.

        On other platforms this delegates to ``_set_cuda_visible_devices``.
        It passes the platform's device-control environment variable name.
        """
        from vllm.platforms import current_platform

        # Device visibility is not adjusted here on XPU.
        if current_platform.is_xpu():
            return

        self._set_cuda_visible_devices(
            vllm_config, local_dp_rank, current_platform.device_control_env_var
        )

    def _set_cuda_visible_devices(
        self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str
    ):
        """Set ``device_control_env_var`` to this DP rank's device indices.

        The indices come from ``get_device_indices``. The expected local
        range is ``[local_dp_rank * world_size, (local_dp_rank + 1) *
        world_size)``, as reported in the error message below.

        Raises:
            RuntimeError: If the current value of the variable does not cover
                that range (``get_device_indices`` raised ``IndexError``).
        """
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            value = get_device_indices(
                device_control_env_var, local_dp_rank, world_size
            )
        except IndexError as e:
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f'base value: "{os.getenv(device_control_env_var)}"'
            ) from e
        os.environ[device_control_env_var] = value

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.

        The parameters are accepted only to match the signature of the
        handshake this overrides; they are not used here.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.

        Exceptions are logged and re-raised. ``shutdown()`` always runs on
        exit.
        """
        try:
            self.run_busy_loop()  # type: ignore[attr-defined]
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()  # type: ignore[attr-defined]


class DPMoEEngineCoreActor(EngineCoreActorMixin, DPEngineCoreProc):
    """Ray actor used for data-parallel MoE models (wraps DPEngineCoreProc)."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # The global DP rank must be on the config before DPEngineCoreProc
        # reads it as its engine index.
        vllm_config.parallel_config.data_parallel_rank = dp_rank

        # Mixin first: records addresses and sets device visibility.
        EngineCoreActorMixin.__init__(
            self,
            vllm_config,
            addresses,
            dp_rank=dp_rank,
            local_dp_rank=local_dp_rank,
        )
        # Under Ray the addresses are known up front, so the handshake
        # address is left empty.
        DPEngineCoreProc.__init__(
            self,
            vllm_config,
            local_client,
            handshake_address="",
            executor_class=executor_class,
            log_stats=log_stats,
        )


class EngineCoreActor(EngineCoreActorMixin, EngineCoreProc):
    """Ray actor wrapping EngineCoreProc, for non-MoE and/or non-DP cases."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Present this engine to itself as a DP world of one (size 1, rank 0);
        # dp_rank is still used below as the engine index.
        parallel_config = vllm_config.parallel_config
        parallel_config.data_parallel_size = 1
        parallel_config.data_parallel_size_local = 1
        parallel_config.data_parallel_rank = 0

        EngineCoreActorMixin.__init__(
            self,
            vllm_config,
            addresses,
            dp_rank=dp_rank,
            local_dp_rank=local_dp_rank,
        )
        # Empty handshake address: Ray provides all addresses up front.
        EngineCoreProc.__init__(
            self,
            vllm_config,
            local_client,
            "",
            executor_class,
            log_stats,
            engine_index=dp_rank,
        )