core.py 52.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import gc
4
import os
5
import queue
6
import signal
7
8
import threading
import time
9
from collections import deque
10
from collections.abc import Callable, Generator
11
from concurrent.futures import Future
Rui Qiao's avatar
Rui Qiao committed
12
from contextlib import ExitStack, contextmanager
13
from inspect import isclass, signature
14
from logging import DEBUG
15
from typing import Any, TypeVar
16

17
import msgspec
18
19
import zmq

20
21
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
22
from vllm.distributed.parallel_state import is_global_first_rank
23
from vllm.envs import enable_envs_cache
24
from vllm.logger import init_logger
25
from vllm.logging_utils.dump_input import dump_engine_exception
26
from vllm.lora.request import LoRARequest
27
from vllm.multimodal import MULTIMODAL_REGISTRY
28
from vllm.multimodal.cache import engine_receiver_cache_from_config
29
from vllm.tasks import POOLING_TASKS, SupportedTask
30
31
32
33
34
35
36
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.utils import (
    decorate_logs,
    get_hash_fn_by_name,
    make_zmq_socket,
    set_process_title,
)
37
from vllm.utils.gc_utils import maybe_attach_gc_debug_callback
38
from vllm.utils.import_utils import resolve_obj_by_qualname
39
40
41
42
43
44
45
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    generate_scheduler_kv_cache_config,
    get_kv_cache_configs,
    get_request_block_hasher,
    init_none_hash,
)
46
from vllm.v1.core.sched.interface import SchedulerInterface
47
48
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from vllm.v1.engine import (
    EngineCoreOutputs,
    EngineCoreRequest,
    EngineCoreRequestType,
    ReconfigureDistributedRequest,
    ReconfigureRankType,
    UtilityOutput,
    UtilityResult,
)
from vllm.v1.engine.utils import (
    EngineHandshakeMetadata,
    EngineZmqAddresses,
    get_device_indices,
)
63
from vllm.v1.executor.abstract import Executor
64
from vllm.v1.kv_cache_interface import KVCacheConfig
65
from vllm.v1.metrics.stats import SchedulerStats
66
from vllm.v1.outputs import ModelRunnerOutput
67
from vllm.v1.request import Request, RequestStatus
68
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
69
from vllm.v1.structured_output import StructuredOutputManager
70
71
72
73
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# NOTE(review): not referenced in the visible part of this file; presumably a
# polling timeout (seconds) used by code further down — confirm.
POLLING_TIMEOUT_S = 2.5
# How long startup_handshake waits for the front-end's init message.
HANDSHAKE_TIMEOUT_MINS = 5

_R = TypeVar("_R")  # Return type for collective_rpc


class EngineCore:
    """Inner loop of vLLM's Engine."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        executor_fail_callback: Callable | None = None,
    ):
        """Create the model executor, KV caches, and scheduler.

        Args:
            vllm_config: Engine configuration. ``cache_config.num_gpu_blocks``
                and ``cache_config.num_cpu_blocks`` are written in place, and
                ``scheduler_config.chunked_prefill_enabled`` is cleared for
                models without any KV cache groups.
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Forwarded to the scheduler.
            executor_fail_callback: Optional callback registered on the
                executor via ``register_failure_callback``.
        """
        # plugins need to be loaded at the engine/scheduler level too
        from vllm.plugins import load_general_plugins

        load_general_plugins()

        self.vllm_config = vllm_config
        if is_global_first_rank():
            logger.info(
                "Initializing a V1 LLM engine (v%s) with config: %s",
                VLLM_VERSION,
                vllm_config,
            )

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(executor_fail_callback)

        # Placeholder; overwritten by _initialize_kv_caches when the model
        # has a KV cache.
        self.available_gpu_memory_for_kv_cache = -1

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
            vllm_config
        )

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler. The scheduler class may be configured either as a
        # class object or as a fully-qualified name string.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls
            )
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls,
            )

        if len(kv_cache_config.kv_cache_groups) == 0:
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            logger.info("Disabling chunked prefill for model without KVCache")
            vllm_config.scheduler_config.chunked_prefill_enabled = False

        # The scheduler (and the request block hasher below) work in units of
        # cache block_size scaled by the decode context parallel size.
        scheduler_block_size = (
            vllm_config.cache_config.block_size
            * vllm_config.parallel_config.decode_context_parallel_size
        )

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )
        self.use_spec_decode = vllm_config.speculative_config is not None
        if self.scheduler.connector is not None:  # type: ignore
            self.model_executor.init_kv_output_aggregator(
                self.scheduler.connector.get_finished_count()  # type: ignore
            )

        self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
        self.mm_receiver_cache = engine_receiver_cache_from_config(
            vllm_config, mm_registry
        )

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: (
            deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]] | None
        ) = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
            self.batch_queue = deque(maxlen=self.batch_queue_size)

        # Block hashes are needed for prefix caching and by KV connectors.
        self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
        if (
            self.vllm_config.cache_config.enable_prefix_caching
            or self.scheduler.get_kv_connector() is not None
        ):
            caching_hash_fn = get_hash_fn_by_name(
                vllm_config.cache_config.prefix_caching_hash_algo
            )
            init_none_hash(caching_hash_fn)

            self.request_block_hasher = get_request_block_hasher(
                scheduler_block_size, caching_hash_fn
            )

        # Pick the step implementation once, based on whether batch queueing
        # is enabled.
        self.step_fn = (
            self.step if self.batch_queue is None else self.step_with_batch_queue
        )

    def _initialize_kv_caches(
        self, vllm_config: VllmConfig
    ) -> tuple[int, int, KVCacheConfig]:
        """Size the KV caches, initialize them on the executor, and warm up.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config)``.
            ``num_cpu_blocks`` is always 0 here.
        """
        # perf_counter is monotonic, so the elapsed time below cannot be
        # skewed by wall-clock adjustments.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        has_kv_cache = any(kv_cache_specs)
        if has_kv_cache:
            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
                # Elastic-EP scale-up launch: skip profiling and take the KV
                # cache memory size via the DP group (dp_group is expected to
                # have been set before this runs).
                dp_group = getattr(self, "dp_group", None)
                assert dp_group is not None
                self.available_gpu_memory_for_kv_cache = (
                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
                )
                available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                    kv_cache_specs
                )
            else:
                # Profiles the peak memory usage of the model to determine how
                # much memory can be allocated for kv cache.
                available_gpu_memory = self.model_executor.determine_available_memory()
                self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
        else:
            # Attention free models don't need memory for kv cache
            available_gpu_memory = [0] * len(kv_cache_specs)

        assert len(kv_cache_specs) == len(available_gpu_memory)

        kv_cache_configs = get_kv_cache_configs(
            vllm_config, kv_cache_specs, available_gpu_memory
        )
        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
        num_gpu_blocks = scheduler_kv_cache_config.num_blocks
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info(
            ("init engine (profile, create kv cache, warmup model) took %.2f seconds"),
            elapsed,
        )
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the model executor."""
        return self.model_executor.supported_tasks

    def add_request(self, request: Request, request_wave: int = 0):
        """Add request to the scheduler.

        `request_wave`: indicate which wave of requests this is expected to
        belong to in DP case (unused in this base implementation).

        Raises:
            TypeError: if ``request.request_id`` is not a string.
            ValueError: if the request's pooling task is not supported.
        """
        # Validate the request_id type.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}"
            )

        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks() if task in POOLING_TASKS
            ]

            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {pooling_params.task!r} "
                    f"Supported tasks: {supported_pooling_tasks}"
                )

        if request.kv_transfer_params is not None and (
            not self.scheduler.get_kv_connector()
        ):
            logger.warning(
                "Got kv_transfer_params, but no KVConnector found. "
                "Disabling KVTransfer for this request."
            )

        self.scheduler.add_request(request)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)

    def execute_model_with_error_logging(
        self,
        model_fn: Callable[[SchedulerOutput], ModelRunnerOutput],
        scheduler_output: SchedulerOutput,
    ) -> ModelRunnerOutput:
        """Execute the model and log detailed info on failure.

        Any ``Exception`` raised by ``model_fn`` is re-raised unchanged after
        the engine state has been dumped.
        """
        try:
            return model_fn(scheduler_output)
        except Exception:
            # We do not want to catch BaseException here since we're only
            # interested in dumping info when the exception is due to an
            # error from execute_model itself.

            # NOTE: This method is exception-free
            dump_engine_exception(
                self.vllm_config, scheduler_output, self.scheduler.make_stats()
            )
            # Bare raise re-raises the active exception as-is.
            raise

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        model_output = self.execute_model_with_error_logging(
            self.model_executor.execute_model,  # type: ignore
            scheduler_output,
        )
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0)

    def post_step(self, model_executed: bool) -> None:
        """Feed draft token ids back to the scheduler after a model step.

        Only does work when speculative decoding is enabled and the model
        actually executed in the preceding step.
        """
        if self.use_spec_decode and model_executed:
            # Take the draft token ids.
            draft_token_ids = self.model_executor.take_draft_token_ids()
            if draft_token_ids is not None:
                self.scheduler.update_draft_token_ids(draft_token_ids)

    def step_with_batch_queue(
        self,
    ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, directly return an empty engine core
        output. In other words, filling the batch queue has a higher priority
        than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the
        oldest batch in the batch queue is finished.
        3. Update the scheduler from the output.
        """
        batch_queue = self.batch_queue
        assert batch_queue is not None

        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        assert len(batch_queue) < self.batch_queue_size

        model_executed = False
        if self.scheduler.has_requests():
            scheduler_output = self.scheduler.schedule()
            future = self.model_executor.execute_model(scheduler_output, non_block=True)
            # New batches go on the left; the oldest batch sits at the right.
            batch_queue.appendleft((future, scheduler_output))  # type: ignore[arg-type]

            model_executed = scheduler_output.total_num_scheduled_tokens > 0
            if (
                model_executed
                and len(batch_queue) < self.batch_queue_size
                and not batch_queue[-1][0].done()
            ):
                # Don't block on next worker response unless the queue is full
                # or there are no more requests to schedule.
                return None, True

        elif not batch_queue:
            # Queue is empty. We should not reach here since this method should
            # only be called when the scheduler contains requests or the queue
            # is non-empty.
            return None, False

        # Block until the next result is available.
        future, scheduler_output = batch_queue.pop()
        model_output = self.execute_model_with_error_logging(
            lambda _: future.result(), scheduler_output
        )

        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        return engine_core_outputs, model_executed

    def shutdown(self):
        """Release the structured output backend, executor, and scheduler."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop profiling on the executor."""
        self.model_executor.profile(is_start)

    def reset_mm_cache(self):
        """Clear the multi-modal caches held here and in the executor."""
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver)
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the multi-modal cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # The cache either exists in EngineCore or WorkerWrapperBase
        if self.mm_receiver_cache is not None:
            self.mm_receiver_cache.clear_cache()

        self.model_executor.reset_mm_cache()

    def reset_prefix_cache(self):
        """Delegate prefix cache reset to the scheduler."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Put the executor to sleep at the given level."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: list[str] | None = None):
        """Wake the executor, optionally restricted to ``tags``."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` flag."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run a dummy batch on the executor."""
        self.model_executor.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter on the executor."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter from the executor."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the ids of LoRA adapters known to the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter on the executor."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Ask the executor to save the model's sharded state to ``path``."""
        self.model_executor.save_sharded_state(
            path=path, pattern=pattern, max_size=max_size
        )

    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Forward a collective RPC to the executor and return its results."""
        return self.model_executor.collective_rpc(method, timeout, args, kwargs)

    def save_tensorized_model(
        self,
        tensorizer_config,
    ) -> None:
        """Ask the executor to save the model using ``tensorizer_config``."""
        self.model_executor.save_tensorized_model(
            tensorizer_config=tensorizer_config,
        )

    def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
        """Preprocess the request.

        This function could be directly used in input processing thread to allow
        request initialization running in parallel with Model forward.

        Returns:
            The constructed ``Request`` and the request's ``current_wave``.
        """
        # Note on thread safety: no race condition.
        # `mm_receiver_cache` is reset at the end of LLMEngine init,
        # and will only be accessed in the input processing thread afterwards.
        if self.mm_receiver_cache is not None and request.mm_features:
            request.mm_features = self.mm_receiver_cache.get_and_update_features(
                request.mm_features
            )

        req = Request.from_engine_core_request(request, self.request_block_hasher)
        if req.use_structured_output:
            # Note on thread safety: no race condition.
            # `grammar_init` is only invoked in input processing thread. For
            # `structured_output_manager`, each request is independent and
            # grammar compilation is async. Scheduler always checks grammar
            # compilation status before scheduling request.
            self.structured_output_manager.grammar_init(req)
        return req, request.current_wave


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

506
    # NOTE(review): sentinel bytes payload; presumably put on the output queue
    # (whose item type admits ``bytes``) to tell clients the engine core has
    # died — confirm against the output thread.
    ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
        engine_index: int = 0,
    ):
        """Handshake with the front-end, build the engine, and start IO threads.

        Args:
            vllm_config: Engine configuration (may be updated by the
                handshake).
            local_client: Whether the client is colocated with this engine.
            handshake_address: ZMQ address of the front-end handshake socket.
            executor_class: Executor type passed on to ``EngineCore``.
            log_stats: Passed on to ``EngineCore``.
            client_handshake_address: Optional address of a second, colocated
                front-end to handshake with.
            engine_index: Index of this engine; also encoded into its ZMQ
                identity.
        """
        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
        self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]()

        def executor_fail_callback() -> None:
            # Enqueue an EXECUTOR_FAILED request on the input queue that the
            # busy loop consumes.
            self.input_queue.put_nowait((EngineCoreRequestType.EXECUTOR_FAILED, b""))

        self.engine_index = engine_index
        # 2-byte little-endian engine index, used as this engine's ZMQ
        # identity for the handshake and the input socket thread.
        identity = self.engine_index.to_bytes(length=2, byteorder="little")
        self.engines_running = False

        with self._perform_handshakes(
            handshake_address,
            identity,
            local_client,
            vllm_config,
            client_handshake_address,
        ) as addresses:
            self.client_count = len(addresses.outputs)

            # Set up data parallel environment.
            self.has_coordinator = addresses.coordinator_output is not None
            self.frontend_stats_publish_address = (
                addresses.frontend_stats_publish_address
            )
            logger.debug(
                "Has DP Coordinator: %s, stats publish address: %s",
                self.has_coordinator,
                self.frontend_stats_publish_address,
            )
            # Only publish request queue stats to coordinator for "internal"
            # and "hybrid" LB modes.
            self.publish_dp_lb_stats = (
                self.has_coordinator
                and not vllm_config.parallel_config.data_parallel_external_lb
            )

            self._init_data_parallel(vllm_config)

            super().__init__(
                vllm_config, executor_class, log_stats, executor_fail_callback
            )

            # Background Threads and Queues for IO. These enable us to
            # overlap ZMQ socket IO with GPU since they release the GIL,
            # and to overlap some serialization/deserialization with the
            # model forward pass.
            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
            ready_event = threading.Event()
            input_thread = threading.Thread(
                target=self.process_input_sockets,
                args=(
                    addresses.inputs,
                    addresses.coordinator_input,
                    identity,
                    ready_event,
                ),
                daemon=True,
            )
            input_thread.start()

            self.output_thread = threading.Thread(
                target=self.process_output_sockets,
                args=(
                    addresses.outputs,
                    addresses.coordinator_output,
                    self.engine_index,
                ),
                daemon=True,
            )
            self.output_thread.start()

            # Don't complete handshake until DP coordinator ready message is
            # received.
            while not ready_event.wait(timeout=10):
                if not input_thread.is_alive():
                    raise RuntimeError("Input socket thread died during startup")
                assert addresses.coordinator_input is not None
                logger.info("Waiting for READY message from DP Coordinator...")

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        gc.collect()
        gc.freeze()

        # If enable, attach GC debugger after static variable freeze.
        maybe_attach_gc_debug_callback()

        # Enable environment variable cache (e.g. assume no more
        # environment variable overrides after this point)
        enable_envs_cache()

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """
        Perform startup handshakes.

        For DP=1 or offline mode, this is with the colocated front-end process.

        For DP>1 with internal load-balancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external or hybrid load-balancing, two handshakes are
        performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
        with the exception of the rank 0 and colocated engines themselves which
        don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
        server is not scaled out), OR the launcher process running the
        run_multi_api_server() function in serve.py.

        Yields the addresses from the primary handshake; when a second,
        colocated handshake is performed, its input/output addresses replace
        the primary ones. After the caller's ``with`` body completes normally,
        ``vllm_config.__post_init__()`` is re-run.
        """
        # NOTE(review): this context is not explicitly terminated here.
        input_ctx = zmq.Context()
        is_local = local_client and client_handshake_address is None
        headless = not local_client
        # Only the primary handshake may update the parallel config.
        handshake = self._perform_handshake(
            input_ctx,
            handshake_address,
            identity,
            is_local,
            headless,
            vllm_config,
            vllm_config.parallel_config,
        )
        if client_handshake_address is None:
            with handshake as addresses:
                yield addresses
        else:
            assert local_client
            local_handshake = self._perform_handshake(
                input_ctx,
                client_handshake_address,
                identity,
                local_client=True,
                headless=False,
                vllm_config=vllm_config,
            )
            with handshake as addresses, local_handshake as client_addresses:
                addresses.inputs = client_addresses.inputs
                addresses.outputs = client_addresses.outputs
                yield addresses

        # Update config which may have changed from the handshake
        vllm_config.__post_init__()

    @contextmanager
    def _perform_handshake(
        self,
        ctx: zmq.Context,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: ParallelConfig | None = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
        """Run one HELLO/READY handshake over a DEALER socket.

        Registers with the front-end, yields the addresses it returns, and,
        once the caller's ``with`` body completes, sends a READY message
        carrying the engine's GPU block count and DP stats address.

        Args:
            ctx: ZMQ context used to create the handshake socket.
            handshake_address: Front-end address to connect to.
            identity: ZMQ identity of this engine.
            local_client: Reported to the front-end as ``"local"``.
            headless: Reported to the front-end as ``"headless"``.
            vllm_config: Source of the values reported in the READY message.
            parallel_config_to_update: If given, updated in place with the
                parallel config values sent by the front-end.
        """
        with make_zmq_socket(
            ctx,
            handshake_address,
            zmq.DEALER,
            identity=identity,
            linger=5000,
            bind=False,
        ) as handshake_socket:
            # Register engine with front-end.
            addresses = self.startup_handshake(
                handshake_socket, local_client, headless, parallel_config_to_update
            )
            yield addresses

            # Send ready message.
            num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
            # We pass back the coordinator stats update address here for the
            # external LB case for our colocated front-end to use (coordinator
            # only runs with rank 0).
            dp_stats_address = self.frontend_stats_publish_address

            # Include config hash for DP configuration validation
            ready_msg: dict[str, Any] = {
                "status": "READY",
                "local": local_client,
                "headless": headless,
                "num_gpu_blocks": num_gpu_blocks,
                "dp_stats_address": dp_stats_address,
            }
            if vllm_config.parallel_config.data_parallel_size > 1:
                ready_msg["parallel_config_hash"] = (
                    vllm_config.parallel_config.compute_hash()
                )

            handshake_socket.send(msgspec.msgpack.encode(ready_msg))

    @staticmethod
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
        headless: bool,
        parallel_config: ParallelConfig | None = None,
    ) -> EngineZmqAddresses:
        """Send HELLO to the front-end and wait for its init message.

        Args:
            handshake_socket: Connected handshake socket.
            local_client: Reported to the front-end as ``"local"``.
            headless: Reported to the front-end as ``"headless"``.
            parallel_config: If given, each key/value in the init message's
                ``parallel_config`` is set on it as an attribute.

        Returns:
            The addresses carried by the init message.

        Raises:
            RuntimeError: if no reply arrives within HANDSHAKE_TIMEOUT_MINS.
        """
        # Send registration message.
        handshake_socket.send(
            msgspec.msgpack.encode(
                {
                    "status": "HELLO",
                    "local": local_client,
                    "headless": headless,
                }
            )
        )

        # Receive initialization message.
        logger.info("Waiting for init message from front-end.")
        # poll() takes milliseconds.
        if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
            raise RuntimeError(
                "Did not receive response from front-end "
                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                "minutes"
            )
        init_bytes = handshake_socket.recv()
        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
            init_bytes, type=EngineHandshakeMetadata
        )
        logger.debug("Received init message: %s", init_message)

        if parallel_config is not None:
            for key, value in init_message.parallel_config.items():
                setattr(parallel_config, key, value)

        return init_message.addresses


    @staticmethod
    def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
        """Entry point of the engine-core background process.

        Installs SIGTERM/SIGINT handlers, then builds a ``DPEngineCoreProc``
        (when data parallel size > 1 or ``dp_rank`` > 0) or a plain
        ``EngineCoreProc`` from ``*args``/``**kwargs`` and runs its busy loop.
        On an unexpected error the client is notified that the engine is dead,
        if the engine was constructed. The engine is always shut down on exit.
        """
        # Only the first termination signal raises SystemExit, so that this
        # process and its workers can terminate without error.
        exit_already_raised = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def _raise_exit_once(signum, frame):
            nonlocal exit_already_raised
            if exit_already_raised:
                return
            exit_already_raised = True
            raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        for sig in (signal.SIGTERM, signal.SIGINT):
            signal.signal(sig, _raise_exit_once)

        engine_core: EngineCoreProc | None = None
        try:
            parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1 or dp_rank > 0:
                set_process_title("EngineCore", f"DP{dp_rank}")
                decorate_logs()
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                core_cls = DPEngineCoreProc
            else:
                set_process_title("EngineCore")
                decorate_logs()
                core_cls = EngineCoreProc
            engine_core = core_cls(*args, **kwargs)
            engine_core.run_busy_loop()
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is not None:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            else:
                logger.exception("EngineCore failed to start.")
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

806
807
808
    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Hook for data-parallel setup; does nothing for the non-DP engine.

        ``DPEngineCoreProc`` overrides this to create its DP process group.
        """

809
810
811
    def run_busy_loop(self):
        """Run the EngineCore loop forever.

        Each iteration first handles client input, blocking while idle, and
        then performs one engine step. The loop never returns normally. It
        ends only when an exception escapes, such as the ``SystemExit`` raised
        by the SIGINT/SIGTERM handler installed in ``run_engine_core``.
        """
        while True:
            # Wait for work (if idle) and apply any queued client requests.
            self._process_input_queue()
            # Advance the engine by one step and publish the results.
            self._process_engine_step()

    def _process_input_queue(self):
        """Handle client requests and return once an engine step is due.

        While there is nothing to step, this blocks on ``input_queue`` and
        handles requests one at a time. "Nothing to step" means
        ``engines_running`` is False, the scheduler has no requests and
        ``batch_queue`` is empty. Before returning, it also handles every
        request already queued, without blocking.
        """
        logged_idle = False
        while not (
            self.engines_running
            or self.scheduler.has_requests()
            or self.batch_queue
        ):
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                logged_idle = True
            self._handle_client_request(*self.input_queue.get())

        if logged_idle:
            logger.debug("EngineCore loop active.")

        # Drain any further requests that are already waiting.
        while not self.input_queue.empty():
            self._handle_client_request(*self.input_queue.get_nowait())

842
    def _process_engine_step(self) -> bool:
        """Run one engine step and enqueue its per-client outputs.

        Each ``(client_index, outputs)`` item produced by ``step_fn`` is put
        on ``output_queue``, then ``post_step`` is invoked.

        Returns:
            Whether the model was executed during this step.
        """
        outputs, model_executed = self.step_fn()
        if outputs:
            for client_output in outputs.items():
                self.output_queue.put_nowait(client_output)
        self.post_step(model_executed)
        return model_executed

855
856
857
    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Dispatch a decoded client request according to its type.

        Request types are handled as follows:

        - ADD: the payload is ``(request, wave)`` and is passed to
          ``add_request``.
        - ABORT: the payload is passed to ``abort_requests``.
        - UTILITY: the payload is ``(client_idx, call_id, method_name, args)``.
          The named method is called on this object, and the result or a
          failure message is sent back to the calling client via
          ``output_queue``.
        - EXECUTOR_FAILED: raises ``RuntimeError``.
        - Any other type is logged and ignored.
        """
        if request_type == EngineCoreRequestType.ADD:
            req, request_wave = request
            self.add_request(req, request_wave)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            client_idx, call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                result = method(*self._convert_msgspec_args(method, args))
                output.result = UtilityResult(result)
            except SystemExit:
                # The SIGTERM/SIGINT handler raises SystemExit only once.
                # Reporting it as a utility failure would swallow the
                # shutdown request and leave the engine running, so let it
                # propagate.
                raise
            except BaseException as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (
                    f"Call to {method_name} method failed: {str(e)}"
                )
            self.output_queue.put_nowait(
                (client_idx, EngineCoreOutputs(utility_output=output))
            )
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error(
                "Unrecognized input request type encountered: %s", request_type
            )
886
887
888
889

    @staticmethod
    def _convert_msgspec_args(method, args):
        """Coerce positional utility arguments to the method's msgspec types.

        An argument is converted with ``msgspec.convert`` when its matching
        parameter is annotated with a ``msgspec.Struct`` subclass and the
        argument is not already an instance of that class. All other
        arguments pass through unchanged. Falsy ``args`` is returned as-is.
        """
        if not args:
            return args
        params = signature(method).parameters.values()
        assert len(args) <= len(params)

        def _coerce(value, param):
            annotation = param.annotation
            needs_conversion = (
                isclass(annotation)
                and issubclass(annotation, msgspec.Struct)
                and not isinstance(value, annotation)
            )
            if needs_conversion:
                return msgspec.convert(value, type=annotation)
            return value

        return tuple(_coerce(value, param) for value, param in zip(args, params))
903

904
905
906
907
908
909
910
911
912
    def _send_engine_dead(self):
        """Tell the EngineCoreClient that this engine has died.

        Enqueues the ``ENGINE_CORE_DEAD`` sentinel, then waits up to 5 seconds
        for ``output_thread`` to finish. If the thread is still alive after
        that, a fatal message is logged.
        """
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Presumably output_thread runs process_output_sockets, which exits
        # its loop right after forwarding the sentinel. A finished thread
        # therefore means the message went out.
        output_thread = self.output_thread
        output_thread.join(timeout=5.0)
        if not output_thread.is_alive():
            return
        logger.fatal(
            "vLLM shutdown signal from EngineCore failed "
            "to send. Please report this issue."
        )
917

918
919
920
    def process_input_sockets(
        self,
        input_addresses: list[str],
        coord_input_address: str | None,
        identity: bytes,
        ready_event: threading.Event,
    ):
        """Input socket IO thread.

        Setup:

        - Connects a DEALER socket to each address in ``input_addresses``.
        - If ``coord_input_address`` is given, connects an XSUB socket to the
          DP coordinator and waits for its READY message.
        - Sets ``ready_event`` once all sockets are registered.

        It then loops forever, decoding incoming messages and pushing
        ``(request_type, request)`` tuples onto ``self.input_queue`` for the
        busy loop.

        Raises:
            RuntimeError: If the coordinator's first message is not READY.
        """
        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with ExitStack() as stack, zmq.Context() as ctx:
            input_sockets = [
                stack.enter_context(
                    make_zmq_socket(
                        ctx, input_address, zmq.DEALER, identity=identity, bind=False
                    )
                )
                for input_address in input_addresses
            ]
            if coord_input_address is None:
                coord_socket = None
            else:
                coord_socket = stack.enter_context(
                    make_zmq_socket(
                        ctx,
                        coord_input_address,
                        zmq.XSUB,
                        identity=identity,
                        bind=False,
                    )
                )
                # Send subscription message to coordinator.
                coord_socket.send(b"\x01")

            # Register sockets with poller.
            poller = zmq.Poller()
            for input_socket in input_sockets:
                # Send initial message to each input socket - this is required
                # before the front-end ROUTER socket can send input messages
                # back to us.
                input_socket.send(b"")
                poller.register(input_socket, zmq.POLLIN)

            if coord_socket is not None:
                # Wait for ready message from coordinator. This must not live
                # inside an assert: under `python -O` the recv() itself would
                # be stripped, and the unread READY frame would later be
                # misparsed as a request type, killing this thread.
                ready_msg = coord_socket.recv()
                if ready_msg != b"READY":
                    raise RuntimeError(
                        f"Expected READY from DP coordinator, got {ready_msg!r}"
                    )
                poller.register(coord_socket, zmq.POLLIN)

            ready_event.set()
            del ready_event

            while True:
                for input_socket, _ in poller.poll():
                    # (RequestType, RequestData)
                    type_frame, *data_frames = input_socket.recv_multipart(copy=False)
                    request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                    # Deserialize the request data.
                    if request_type == EngineCoreRequestType.ADD:
                        request = add_request_decoder.decode(data_frames)
                        request = self.preprocess_add_request(request)
                    else:
                        request = generic_decoder.decode(data_frames)

                    # Push to input queue for core busy loop.
                    self.input_queue.put_nowait((request_type, request))

987
988
989
    def process_output_sockets(
        self,
        output_paths: list[str],
        coord_output_path: str | None,
        engine_index: int,
    ):
        """Output socket IO thread.

        Sockets:

        - Opens one PUSH socket per entry in ``output_paths``.
        - Opens a PUSH socket to the coordinator if ``coord_output_path`` is
          given.

        Loop behavior:

        - Tags each ``(client_index, outputs)`` item from ``output_queue``
          with ``engine_index`` and encodes it.
        - Sends it to the coordinator when ``client_index == -1``, otherwise
          to ``sockets[client_index]``.
        - On the ``ENGINE_CORE_DEAD`` sentinel, forwards it to every client
          socket and exits.

        Encode buffers are recycled once zmq reports a send complete. At most
        ``len(sockets) + 1`` buffers are kept for reuse.
        """
        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with ExitStack() as stack, zmq.Context() as ctx:
            sockets = [
                stack.enter_context(
                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)
                )
                for output_path in output_paths
            ]
            coord_socket = (
                stack.enter_context(
                    make_zmq_socket(
                        ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000
                    )
                )
                if coord_output_path is not None
                else None
            )
            max_reuse_bufs = len(sockets) + 1

            while True:
                output = self.output_queue.get()
                if output == EngineCoreProc.ENGINE_CORE_DEAD:
                    for socket in sockets:
                        socket.send(output)
                    break
                assert not isinstance(output, bytes)
                client_index, outputs = output
                outputs.engine_index = engine_index

                if client_index == -1:
                    # Don't reuse buffer for coordinator message
                    # which will be very small.
                    assert coord_socket is not None
                    coord_socket.send_multipart(encoder.encode(outputs))
                    continue

                # Reclaim buffers that zmq is finished with. Completed entries
                # are always popped, which releases their output references.
                # Their buffers are kept only up to max_reuse_bufs, the same
                # limit applied below, so a burst of in-flight sends cannot
                # grow the reuse pool past that bound.
                while pending and pending[-1][0].done:
                    _, _, done_buffer = pending.pop()
                    if len(reuse_buffers) < max_reuse_bufs:
                        reuse_buffers.append(done_buffer)

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = sockets[client_index].send_multipart(
                    buffers, copy=False, track=True
                )
                if not tracker.done:
                    # Multi-frame messages reference memory owned by
                    # `outputs`, so keep it alive until zmq is done.
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
                elif len(reuse_buffers) < max_reuse_bufs:
                    # Limit the number of buffers to reuse.
                    reuse_buffers.append(buffer)
1056
1057
1058
1059
1060
1061
1062
1063
1064


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        client_handshake_address: str | None = None,
    ):
        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.step_counter = 0
        # Index of the DP wave this engine is currently participating in.
        self.current_wave = 0
        # Request counts most recently published by
        # _maybe_publish_request_counts; unchanged counts are not re-sent.
        self.last_counts = (0, 0)

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(
            vllm_config,
            local_client,
            handshake_address,
            executor_class,
            log_stats,
            client_handshake_address,
            dp_rank,
        )

    def _init_data_parallel(self, vllm_config: VllmConfig):
        """Record the DP rank and create the stateless DP process group.

        When a KV-transfer config is present, its ``engine_id`` is suffixed
        with ``_dp{local_dp_rank}`` so that each DP rank gets a distinct id.
        """
        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        if vllm_config.kv_transfer_config is not None:
            # modify the engine_id and append the local_dp_rank to it to ensure
            # that the kv_transfer_config is unique for each DP rank.
            vllm_config.kv_transfer_config.engine_id = (
                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
            )
            logger.debug(
                "Setting kv_transfer_config.engine_id to %s",
                vllm_config.kv_transfer_config.engine_id,
            )

        self.dp_rank = dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

    def shutdown(self):
        """Shut down the engine, then destroy the DP process group if set."""
        super().shutdown()
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: Request, request_wave: int = 0):
        """Add a request while keeping DP wave tracking up to date.

        With a coordinator present:

        - A request tagged with a newer wave advances ``current_wave``.
        - A request for an older wave that arrives while engines are paused
          sends a ``start_wave`` message to the coordinator (client index -1).
        """
        if self.has_coordinator and request_wave != self.current_wave:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    (-1, EngineCoreOutputs(start_wave=self.current_wave))
                )

        super().add_request(request, request_wave)

    def _handle_client_request(
        self, request_type: EngineCoreRequestType, request: Any
    ) -> None:
        """Handle START_DP_WAVE locally; delegate every other request type.

        A START_DP_WAVE payload is ``(new_wave, exclude_eng_index)``. It is
        ignored if it excludes this engine or refers to an older wave.
        Otherwise ``current_wave`` is updated and a paused engine starts
        running.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave, exclude_eng_index = request
            if exclude_eng_index != self.engine_index and (
                new_wave >= self.current_wave
            ):
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.", new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def _maybe_publish_request_counts(self):
        """Send scheduler request counts to the coordinator when they change.

        Does nothing unless ``publish_dp_lb_stats`` is set.
        """
        if not self.publish_dp_lb_stats:
            return

        # Publish our request counts (if they've changed).
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(
                *counts, step_counter=self.step_counter, current_wave=self.current_wave
            )
            self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats)))

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            # 2) Step the engine core.
            executed = self._process_engine_step()
            self._maybe_publish_request_counts()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
            if not executed:
                if not local_unfinished_reqs and not self.engines_running:
                    # All engines are idle.
                    continue

                # We are in a running state and so must execute a dummy pass
                # if the model didn't execute any ready requests.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs
            )

            if not self.engines_running:
                if self.dp_rank == 0 or not self.has_coordinator:
                    # Notify client that we are pausing the loop.
                    logger.debug(
                        "Wave %d finished, pausing engine loop.", self.current_wave
                    )
                    # In the coordinator case, dp rank 0 sends updates to the
                    # coordinator. Otherwise (offline spmd case), each rank
                    # sends the update to its colocated front-end process.
                    client_index = -1 if self.has_coordinator else 0
                    self.output_queue.put_nowait(
                        (
                            client_index,
                            EngineCoreOutputs(wave_complete=self.current_wave),
                        )
                    )
                # Increment wave count and reset step counter.
                self.current_wave += 1
                self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank still has unfinished requests.

        Only every 32nd call performs the collective check over ``dp_group``.
        The other calls return True without communicating.
        """
        # Optimization - only perform finish-sync all-reduce every 32 steps.
        self.step_counter += 1
        if self.step_counter % 32 != 0:
            return True

        return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished)

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Rebuild this engine's distributed state for a new DP layout.

        Steps:

        1. Tear down the current DP group and shut the engine down.
        2. Apply the new DP size, rank and master ip/port from
           ``reconfig_request``.
        3. Re-create the DP group, unless this rank is being shut down.
        4. Reinitialize the model executor.
        5. When the DP size grows, share the KV-cache memory size with the
           new ranks and run ``compile_or_warm_up_model``.
        """
        # NOTE(review): self.shutdown() also destroys self.dp_group (see
        # shutdown above), so the group looks like it is destroyed twice
        # here - confirm this is harmless.
        stateless_destroy_torch_distributed_process_group(self.dp_group)
        self.shutdown()

        parallel_config = self.vllm_config.parallel_config
        old_dp_size = parallel_config.data_parallel_size
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        # local rank specifies device visibility, it should not be changed
        assert (
            reconfig_request.new_data_parallel_rank_local
            == ReconfigureRankType.KEEP_CURRENT_RANK
        )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.dp_rank = parallel_config.data_parallel_rank
            self.dp_group = parallel_config.stateless_init_dp_group()
        # Presumably stateless_init_dp_group() may update the master port;
        # propagate the value the config ends up with to the executor.
        # TODO confirm.
        reconfig_request.new_data_parallel_master_port = (
            parallel_config.data_parallel_master_port
        )

        self.model_executor.reinitialize_distributed(reconfig_request)
        if reconfig_request.new_data_parallel_size > old_dp_size:
            assert self.available_gpu_memory_for_kv_cache > 0
            # pass available_gpu_memory_for_kv_cache from existing
            # engine-cores to new engine-cores so they can directly
            # use it in _initialize_kv_caches() rather than profiling.
            ParallelConfig.sync_kv_cache_memory_size(
                self.dp_group, self.available_gpu_memory_for_kv_cache
            )
            # NOTE(yongji): newly joined workers require dummy_run even
            # CUDA graph is not used
            self.model_executor.collective_rpc("compile_or_warm_up_model")
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()
            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
        else:
            logger.info(
                "Distributed environment reinitialized for DP rank %s", self.dp_rank
            )

Rui Qiao's avatar
Rui Qiao committed
1265
1266
1267
1268
1269
1270
1271
1272
1273

class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        # Addresses are known up front for Ray; _perform_handshakes yields them.
        self.addresses = addresses
        vllm_config.parallel_config.data_parallel_rank = dp_rank
        vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_visible_devices(vllm_config, local_dp_rank)

        super().__init__(vllm_config, local_client, "", executor_class, log_stats)

    def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int):
        """Restrict visible devices to this local DP rank's slice.

        Does nothing on XPU. Otherwise it updates the platform's
        device-control environment variable.
        """
        from vllm.platforms import current_platform

        if current_platform.is_xpu():
            # NOTE(review): device visibility is intentionally left untouched
            # on XPU.
            return
        self._set_cuda_visible_devices(
            vllm_config, local_dp_rank, current_platform.device_control_env_var
        )

    def _set_cuda_visible_devices(
        self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str
    ):
        """Set ``device_control_env_var`` to this DP rank's device indices.

        Raises:
            RuntimeError: If ``get_device_indices`` raises ``IndexError``. The
                message reports the expected index range
                ``[local_dp_rank * world_size, (local_dp_rank + 1) *
                world_size)`` and the variable's current value.
        """
        world_size = vllm_config.parallel_config.world_size
        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            value = get_device_indices(
                device_control_env_var, local_dp_rank, world_size
            )
        except IndexError as e:
            # RuntimeError (a subclass of Exception) stays compatible with
            # existing `except Exception` handlers but is more specific.
            raise RuntimeError(
                f"Error setting {device_control_env_var}: "
                f"local range: [{local_dp_rank * world_size}, "
                f"{(local_dp_rank + 1) * world_size}) "
                f'base value: "{os.getenv(device_control_env_var)}"'
            ) from e
        os.environ[device_control_env_var] = value

    @contextmanager
    def _perform_handshakes(
        self,
        handshake_address: str,
        identity: bytes,
        local_client: bool,
        vllm_config: VllmConfig,
        client_handshake_address: str | None,
    ):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.

        Errors other than SystemExit are logged before re-raising, and the
        engine is always shut down afterwards.
        """
        try:
            self.run_busy_loop()
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            self.shutdown()