# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""AsyncLLM: asyncio-facing V1 engine client (implements EngineClient)."""
import asyncio
import os
import socket
import time
from collections.abc import AsyncGenerator, Iterable, Mapping
from copy import copy
from typing import Any, Optional, Union

import numpy as np
import torch

import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tracing import init_tracer
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
                                               init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Device, as_list, cancel_task_threadsafe, cdiv,
                        deprecate_kwargs)
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.output_processor import (OutputProcessor,
                                             RequestOutputCollector)
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats

# Module-level logger used by AsyncLLM for request / lifecycle logging.
logger = init_logger(__name__)


class AsyncLLM(EngineClient):
    """Asyncio front-end for the V1 engine.

    Converts prompts into EngineCoreRequests (via ``Processor``), forwards
    them to a background EngineCore process, and streams results back to
    callers through per-request ``RequestOutputCollector`` queues that a
    single background ``output_handler`` task fills.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        client_addresses: Optional[dict[str, str]] = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats.
            usage_context: Usage context of the LLM.
            mm_registry: Multi-modal registry.
            use_cached_outputs: Whether to use cached outputs.
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop.
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            client_addresses: passed through to the EngineCore client.
            client_count: passed to the EngineCore client and stat loggers.
            client_index: passed through to the EngineCore client.

        Returns:
            None

        Raises:
            ValueError: if ``envs.VLLM_USE_V1`` is false.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.log_requests = log_requests

        # Custom stat loggers force stats collection on, but the default
        # loggers are only enabled when log_stats itself was requested.
        self.log_stats = log_stats or (stat_loggers is not None)
        if not log_stats and stat_loggers is not None:
            logger.info(
                "AsyncLLM created with log_stats=False and non-empty custom "
                "logger list; enabling logging without default stat loggers")

        if self.model_config.skip_tokenizer_init:
            self.tokenizer = None
        else:
            # Tokenizer (+ ensure liveness if running in another process).
            self.tokenizer = init_tokenizer_from_configs(
                model_config=vllm_config.model_config)

        # Processor (converts Inputs --> EngineCoreRequests).
        self.processor = Processor(
            vllm_config=vllm_config,
            tokenizer=self.tokenizer,
            mm_registry=mm_registry,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(self.tokenizer,
                                                log_stats=self.log_stats)
        if self.observability_config.otlp_traces_endpoint is not None:
            tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)
            self.output_processor.tracer = tracer

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Loggers.
        self.logger_manager: Optional[StatLoggerManager] = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=stat_loggers,
                enable_default_loggers=log_stats,
                client_count=client_count,
            )
            self.logger_manager.log_engine_initialized()

        self.output_handler: Optional[asyncio.Task] = None
        # Start output handler eagerly if we are in the asyncio eventloop.
        # Only the loop probe is guarded: a RuntimeError raised by the
        # handler setup itself must not be silently swallowed.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass
        else:
            self._run_output_handler()

        if envs.VLLM_TORCH_PROFILER_DIR:
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                envs.VLLM_TORCH_PROFILER_DIR)
            worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    envs.VLLM_TORCH_PROFILER_DIR,
                    worker_name=worker_name,
                    use_gzip=True))
        else:
            self.profiler = None

    @classmethod
    @deprecate_kwargs(
        "disable_log_requests",
        additional_message=("This argument will have no effect. "
                            "Use `enable_log_requests` instead."),
    )
    def from_vllm_config(
            cls,
            vllm_config: VllmConfig,
            start_engine_loop: bool = True,
            usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
            stat_loggers: Optional[list[StatLoggerFactory]] = None,
            enable_log_requests: bool = False,
            disable_log_stats: bool = False,
            client_addresses: Optional[dict[str, str]] = None,
            client_count: int = 1,
            client_index: int = 0,
            disable_log_requests: bool = True,  # Deprecated, will be removed
    ) -> "AsyncLLM":
        """Create an AsyncLLM from an already-built VllmConfig.

        The executor class is derived from ``vllm_config``.
        ``disable_log_requests`` is deprecated and ignored; use
        ``enable_log_requests`` instead.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        # Create the LLMEngine.
        return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            log_requests=enable_log_requests,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs."""

        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    def __del__(self):
        self.shutdown()

    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC."""

        shutdown_prometheus()

        # getattr: shutdown may run from __del__ after a partially
        # failed __init__, when these attributes were never set.
        if engine_core := getattr(self, "engine_core", None):
            engine_core.shutdown()

        cancel_task_threadsafe(getattr(self, "output_handler", None))

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        return await self.engine_core.get_supported_tasks_async()

    async def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        Returns the collector that the output handler will push this
        request's outputs into. For sampling with ``n > 1`` the request is
        fanned out into ``n`` child requests sharing that one collector.

        Raises:
            EngineDeadError: if the engine has already errored.
        """

        if self.errored:
            raise EngineDeadError()

        is_pooling = isinstance(params, PoolingParams)

        # Create a new output collector for the request.
        queue = RequestOutputCollector(output_kind=params.output_kind)

        # Convert Input --> Request.
        prompt_str, request = self.processor.process_inputs(
            request_id, prompt, params, arrival_time, lora_request,
            tokenization_kwargs, trace_headers, priority, data_parallel_rank)

        if is_pooling or params.n == 1:
            await self._add_request(request, prompt_str, None, 0, queue)
            return queue

        # Get the updated SamplingParams from the request, which
        # were cloned/updated in processor.process_inputs above.
        parent_params = request.sampling_params
        assert parent_params is not None

        # Fan out child requests (for n>1). The last child reuses the
        # original request object; earlier ones get shallow copies.
        parent_request = ParentRequest(request_id, parent_params)
        for idx in range(parent_params.n):
            child_request_id, child_params = parent_request.get_child_info(
                idx)
            child_request = request if idx == parent_params.n - 1 else copy(
                request)
            child_request.request_id = child_request_id
            child_request.sampling_params = child_params
            await self._add_request(child_request, prompt_str, parent_request,
                                    idx, queue)
        return queue

    async def _add_request(self, request: EngineCoreRequest,
                           prompt: Optional[str],
                           parent_req: Optional[ParentRequest], index: int,
                           queue: RequestOutputCollector):
        """Register ``request`` locally, then send it to the EngineCore."""

        # Add the request to OutputProcessor (this process).
        self.output_processor.add_request(request, prompt, parent_req, index,
                                          queue)

        # Add the EngineCoreRequest to EngineCore (separate process).
        await self.engine_core.add_request_async(request)

        if self.log_requests:
            logger.info("Added request %s.", request.request_id)

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Making an AsyncStream corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the Detokenizer.
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request AsyncStream.

        The caller of generate() iterates the returned AsyncGenerator,
        returning the RequestOutput back to the caller.
        """

        if (self.vllm_config.cache_config.kv_sharing_fast_prefill
                and sampling_params.prompt_logprobs):
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs")

        try:
            # We start the output_handler on the first call to generate() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()

            tokenization_kwargs: dict[str, Any] = {}
            truncate_prompt_tokens = sampling_params.truncate_prompt_tokens

            _validate_truncation_size(
                self.model_config.max_model_len,
                truncate_prompt_tokens,
                tokenization_kwargs,
            )

            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
                tokenization_kwargs=tokenization_kwargs,
                data_parallel_rank=data_parallel_rank,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    def _run_output_handler(self):
        """Background loop: pulls from EngineCore and pushes to AsyncStreams.

        Idempotent: does nothing if the handler task already exists.
        """

        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        logger_manager = self.logger_manager

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    num_outputs = len(outputs.outputs)

                    iteration_stats = IterationStats() if (
                        log_stats and num_outputs) else None

                    # Split outputs into chunks of at most
                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                    # event loop for too long.
                    if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
                        slices = (outputs.outputs, )
                    else:
                        slices = np.array_split(
                            outputs.outputs,
                            cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))

                    for i, outputs_slice in enumerate(slices):
                        # 2) Process EngineCoreOutputs.
                        processed_outputs = output_processor.process_outputs(
                            outputs_slice, outputs.timestamp, iteration_stats)
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed_outputs.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if i + 1 < len(slices):
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        await engine_core.abort_requests_async(
                            processed_outputs.reqs_to_abort)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if logger_manager:
                        logger_manager.record(
                            engine_idx=outputs.engine_index,
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                        )
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())

    async def abort(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort RequestId in OutputProcessor and EngineCore.

        Accepts a single id or an iterable of ids. The OutputProcessor
        returns the full set of ids to abort in the EngineCore.
        """

        request_ids = (request_id, ) if isinstance(
            request_id, str) else as_list(request_id)
        all_request_ids = self.output_processor.abort_requests(request_ids)
        await self.engine_core.abort_requests_async(all_request_ids)

        if self.log_requests:
            logger.info("Aborted request(s) %s.", ",".join(request_ids))

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        truncate_prompt_tokens: Optional[int] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Main function called by the API server to kick off a pooling request
            * 1) Making an AsyncStream corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request AsyncStream.

        The caller of encode() iterates the returned AsyncGenerator,
        receiving PoolingRequestOutput objects.
        """

        try:
            # We start the output_handler on the first call to encode() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()

            if tokenization_kwargs is None:
                tokenization_kwargs = dict[str, Any]()
            _validate_truncation_size(
                self.model_config.max_model_len,
                truncate_prompt_tokens,
                tokenization_kwargs,
            )

            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
                tokenization_kwargs=tokenization_kwargs,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, encode() is
        # cancelled or the generator is closed / garbage collected
        # (GeneratorExit, a BaseException that the generic handler below
        # would not catch). Either way, abort so the request is not leaked.
        except (asyncio.CancelledError, GeneratorExit):
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    async def get_vllm_config(self) -> VllmConfig:
        return self.vllm_config

    async def get_model_config(self) -> ModelConfig:
        return self.model_config

    async def get_input_preprocessor(self) -> InputPreprocessor:
        return self.processor.input_preprocessor

    async def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer; raises ValueError if init was skipped."""
        if self.tokenizer is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")

        return self.tokenizer

    async def is_tracing_enabled(self) -> bool:
        return self.observability_config.otlp_traces_endpoint is not None

    async def do_log_stats(self) -> None:
        if self.logger_manager:
            self.logger_manager.log()

    async def check_health(self) -> None:
        logger.debug("Called check_health.")
        if self.errored:
            raise self.dead_error

    async def start_profile(self) -> None:
        """Start engine-core profiling and, if enabled, the local profiler."""
        coros = [self.engine_core.profile_async(True)]
        if self.profiler is not None:
            coros.append(asyncio.to_thread(self.profiler.start))
        await asyncio.gather(*coros)

    async def stop_profile(self) -> None:
        """Stop engine-core profiling and, if enabled, the local profiler."""
        coros = [self.engine_core.profile_async(False)]
        if self.profiler is not None:
            coros.append(asyncio.to_thread(self.profiler.stop))
        await asyncio.gather(*coros)

    async def reset_mm_cache(self) -> None:
        self.processor.clear_cache()
        await self.engine_core.reset_mm_cache_async()

    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        if device == Device.CPU:
            raise ValueError("Not supported on CPU.")
        await self.engine_core.reset_prefix_cache_async()

    async def sleep(self, level: int = 1) -> None:
        await self.reset_prefix_cache()
        await self.engine_core.sleep_async(level)

    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        await self.engine_core.wake_up_async(tags)

    async def is_sleeping(self) -> bool:
        return await self.engine_core.is_sleeping_async()

    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return await self.engine_core.add_lora_async(lora_request)

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return await self.engine_core.remove_lora_async(lora_id)

    async def list_loras(self) -> set[int]:
        """List all registered adapters."""
        return await self.engine_core.list_loras_async()

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return await self.engine_core.pin_lora_async(lora_id)

    async def collective_rpc(self,
                             method: str,
                             timeout: Optional[float] = None,
                             args: tuple = (),
                             kwargs: Optional[dict] = None):
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
            method, timeout, args, kwargs)

    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """Wait for all requests to be drained.

        Polls once per second until the engines report idle.

        Raises:
            TimeoutError: if still running after ``drain_timeout`` seconds.
        """
        # Monotonic clock: wall-clock adjustments must not stretch or
        # shrink the timeout.
        deadline = time.monotonic() + drain_timeout
        while time.monotonic() < deadline:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

            logger.info(
                "Engines are still running, waiting for requests to drain...")
            await asyncio.sleep(1)  # Wait 1 second before checking again

        raise TimeoutError(f"Timeout reached after {drain_timeout} seconds "
                           "waiting for requests to drain.")

    async def scale_elastic_ep(self,
                               new_data_parallel_size: int,
                               drain_timeout: int = 300):
        """
        Scale up or down the data parallel size by adding or removing
        engine cores.
        Args:
            new_data_parallel_size: The new number of data parallel workers
            drain_timeout:
                Maximum time to wait for requests to drain (seconds)
        """
        old_data_parallel_size = \
            self.vllm_config.parallel_config.data_parallel_size
        if old_data_parallel_size == new_data_parallel_size:
            logger.info("Data parallel size is already %s, skipping scale",
                        new_data_parallel_size)
            return
        # Applies to both scale-up and scale-down.
        logger.info(
            "Waiting for requests to drain before "
            "scaling to %s engines...", new_data_parallel_size)
        await self.wait_for_requests_to_drain(drain_timeout)
        logger.info(
            "Requests have been drained, proceeding with scale "
            "to %s engines", new_data_parallel_size)
        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
        self.vllm_config.parallel_config.data_parallel_size = \
            new_data_parallel_size

        # recreate stat loggers
        if new_data_parallel_size > old_data_parallel_size and self.log_stats:
            # TODO(rob): fix this after talking with Ray team.
            # This resets all the prometheus metrics since we
            # unregister during initialization. Need to understand
            # the intended behavior here better.
            # NOTE(review): a running output_handler captured the previous
            # logger_manager as a local, so it keeps recording into the old
            # manager; only do_log_stats() sees this new one — confirm.
            self.logger_manager = StatLoggerManager(
                vllm_config=self.vllm_config,
                engine_idxs=list(range(new_data_parallel_size)),
                custom_stat_loggers=None,
            )

    @property
    def is_running(self) -> bool:
        # Is None before the loop is started.
        return self.output_handler is None or not self.output_handler.done()

    @property
    def is_stopped(self) -> bool:
        return self.errored

    @property
    def errored(self) -> bool:
        return self.engine_core.resources.engine_dead or not self.is_running

    @property
    def dead_error(self) -> BaseException:
        return EngineDeadError()