async_llm.py 30.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import os
import socket
6
import time
7
from collections.abc import AsyncGenerator, Iterable, Mapping
8
from copy import copy
9
from typing import Any
10

11
import numpy as np
12
import torch
13

14
import vllm.envs as envs
15
from vllm.config import VllmConfig
16
17
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
18
from vllm.entrypoints.utils import _validate_truncation_size
19
from vllm.inputs import PromptType
20
21
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
22
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
23
from vllm.outputs import PoolingRequestOutput, RequestOutput
24
from vllm.plugins.io_processors import get_io_processor
25
from vllm.pooling_params import PoolingParams
26
from vllm.sampling_params import SamplingParams
27
from vllm.tasks import SupportedTask
28
from vllm.tracing import init_tracer
29
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
30
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
31
from vllm.usage.usage_lib import UsageContext
32
from vllm.utils import Device, as_list, cdiv
33
34
from vllm.utils.asyncio import cancel_task_threadsafe
from vllm.utils.functools import deprecate_kwargs
35
from vllm.v1.engine import EngineCoreRequest
36
from vllm.v1.engine.core_client import EngineCoreClient
37
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
38
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
39
from vllm.v1.engine.parallel_sampling import ParentRequest
40
from vllm.v1.engine.processor import Processor
41
from vllm.v1.executor.abstract import Executor
42
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
43
from vllm.v1.metrics.prometheus import shutdown_prometheus
44
from vllm.v1.metrics.stats import IterationStats
45
46
47
48
49
50
51
52

logger = init_logger(__name__)


class AsyncLLM(EngineClient):
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: list[StatLoggerFactory] | None = None,
        aggregate_engine_logging: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Build an AsyncLLM: input processor, output processor, engine-core
        client, optional stat loggers and optional CPU profiler.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats. Stats logging is also turned on
                when ``stat_loggers`` is given, but then the default loggers
                stay disabled.
            usage_context: Usage context of the LLM (not referenced in this
                constructor body).
            mm_registry: Multi-modal registry (not referenced in this
                constructor body).
            use_cached_outputs: Whether to use cached outputs (not referenced
                in this constructor body).
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop (not
                referenced in this constructor body).
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            aggregate_engine_logging: forwarded to StatLoggerManager.
            client_addresses: forwarded to the engine-core client factory.
            client_count: forwarded to the engine-core client factory and
                to StatLoggerManager.
            client_index: forwarded to the engine-core client factory.

        Raises:
            ValueError: if ``envs.VLLM_USE_V1`` is falsy.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github."
            )

        # Custom transformer configs must be serializable by value.
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.log_requests = log_requests

        # Supplying custom stat loggers implies stats logging, even when
        # the caller passed log_stats=False.
        has_custom_loggers = stat_loggers is not None
        self.log_stats = log_stats or has_custom_loggers
        if has_custom_loggers and not log_stats:
            logger.info(
                "AsyncLLM created with log_stats=False and non-empty custom "
                "logger list; enabling logging without default stat loggers"
            )

        # Input side: tokenizer (optional) + Processor + IO processor plugin.
        input_tokenizer = (
            None
            if self.model_config.skip_tokenizer_init
            else init_tokenizer_from_configs(self.model_config)
        )
        self.processor = Processor(self.vllm_config, input_tokenizer)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # Output side: OutputProcessor turns EngineCoreOutputs into
        # RequestOutput. self.tokenizer is a property backed by the
        # processor, so the processor must already exist here.
        self.output_processor = OutputProcessor(
            self.tokenizer, log_stats=self.log_stats
        )
        traces_endpoint = self.observability_config.otlp_traces_endpoint
        if traces_endpoint is not None:
            self.output_processor.tracer = init_tracer(
                "vllm.llm_engine", traces_endpoint
            )

        # EngineCore client (starts the engine in a background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Stat loggers. Default loggers follow the caller's original
        # log_stats flag, not the derived self.log_stats.
        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=stat_loggers,
                enable_default_loggers=log_stats,
                client_count=client_count,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        self.output_handler: asyncio.Task | None = None
        try:
            # get_running_loop() raises RuntimeError outside an event loop;
            # in that case the handler is started lazily later on.
            asyncio.get_running_loop()
            self._run_output_handler()
        except RuntimeError:
            pass

        trace_dir = envs.VLLM_TORCH_PROFILER_DIR
        if not trace_dir:
            self.profiler = None
        else:
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                trace_dir,
            )
            worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
            trace_handler = torch.profiler.tensorboard_trace_handler(
                trace_dir, worker_name=worker_name, use_gzip=True
            )
            self.profiler = torch.profiler.profile(
                activities=[torch.profiler.ProfilerActivity.CPU],
                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
                on_trace_ready=trace_handler,
            )

179
    @classmethod
    @deprecate_kwargs(
        "disable_log_requests",
        additional_message=(
            "This argument will have no effect. Use `enable_log_requests` instead."
        ),
    )
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_log_requests: bool = False,
        aggregate_engine_logging: bool = False,
        disable_log_stats: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
        disable_log_requests: bool = True,  # Deprecated, will be removed
    ) -> "AsyncLLM":
        """Create an AsyncLLM from an already-built VllmConfig.

        The executor class is resolved from ``vllm_config``;
        ``enable_log_requests`` maps to ``log_requests`` and
        ``disable_log_stats`` is inverted into ``log_stats``.
        ``disable_log_requests`` is accepted but not forwarded.

        Raises:
            ValueError: if ``envs.VLLM_USE_V1`` is falsy.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github."
            )

        # Resolve the executor first, then hand everything to the constructor.
        executor_cls = Executor.get_class(vllm_config)
        stats_enabled = not disable_log_stats
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_cls,
            log_stats=stats_enabled,
            usage_context=usage_context,
            log_requests=enable_log_requests,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            aggregate_engine_logging=aggregate_engine_logging,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

223
224
225
226
227
228
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
    ) -> "AsyncLLM":
        """Build an AsyncLLM from AsyncEngineArgs.

        The engine config is derived from ``engine_args`` for the given
        ``usage_context``, and the executor class is resolved from that
        config before the instance is constructed.
        """
        # Derive the engine config, then the executor that matches it.
        config = engine_args.create_engine_config(usage_context)
        executor_cls = Executor.get_class(config)

        stats_enabled = not engine_args.disable_log_stats
        return cls(
            vllm_config=config,
            executor_class=executor_cls,
            log_stats=stats_enabled,
            usage_context=usage_context,
            log_requests=engine_args.enable_log_requests,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
        )

248
249
250
    def __del__(self):
        """Finalizer: delegate cleanup to shutdown()."""
        type(self).shutdown(self)

251
252
253
    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC."""
        shutdown_prometheus()

        # Attributes are fetched with getattr because shutdown() is also
        # reached from __del__, where __init__ may not have completed.
        core = getattr(self, "engine_core", None)
        if core:
            core.shutdown()

        handler = getattr(self, "output_handler", None)
        cancel_task_threadsafe(handler)
260

261
262
263
    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the supported tasks as reported by the engine core."""
        tasks = await self.engine_core.get_supported_tasks_async()
        return tasks

264
265
266
    async def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest | PromptType,
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        prompt_text: str | None = None,
    ) -> RequestOutputCollector:
        """Register a request and return the collector its outputs land in.

        A raw prompt is first converted to an EngineCoreRequest through the
        processor; an EngineCoreRequest is used as-is. Pooling requests and
        sampling requests with ``n == 1`` are submitted once; otherwise one
        child request per sample is submitted, all sharing one collector.

        Raises:
            EngineDeadError: if the engine is already in an errored state.
        """
        if self.errored:
            raise EngineDeadError()

        pooling = isinstance(params, PoolingParams)

        # One collector receives every output for this request.
        collector = RequestOutputCollector(output_kind=params.output_kind)

        # Convert Input --> Request.
        if not isinstance(prompt, EngineCoreRequest):
            assert prompt_text is None
            logger.warning_once(
                "Processor has been moved under OpenAIServing and will "
                "be removed from AsyncLLM in v0.13."
            )
            request = self.processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
                data_parallel_rank,
            )
            if isinstance(prompt, str):
                prompt_text = prompt
            else:
                prompt_text = prompt.get("prompt")
        else:
            request = prompt

        if pooling or params.n == 1:
            await self._add_request(request, prompt_text, None, 0, collector)
            return collector

        # Use the SamplingParams stored on the request rather than the
        # caller's object (they were cloned/updated during processing).
        parent_params = request.sampling_params
        assert parent_params is not None

        # Fan out child requests (for n>1). The last child reuses the
        # original request object; earlier children get shallow copies.
        parent = ParentRequest(request_id, parent_params)
        num_children = parent_params.n
        for child_idx in range(num_children):
            child_id, child_params = parent.get_child_info(child_idx)
            if child_idx == num_children - 1:
                child = request
            else:
                child = copy(request)
            child.request_id = child_id
            child.sampling_params = child_params
            await self._add_request(
                child, prompt_text, parent, child_idx, collector
            )
        return collector
329

330
331
332
    async def _add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        index: int,
        queue: RequestOutputCollector,
    ):
        """Register one request locally, then submit it to the engine core.

        The OutputProcessor (this process) is told about the request before
        it is sent to the EngineCore (separate process).
        """
        out_proc = self.output_processor
        out_proc.add_request(request, prompt, parent_req, index, queue)

        await self.engine_core.add_request_async(request)

        if not self.log_requests:
            return
        logger.info("Added request %s.", request.request_id)
346
347
348
349
350
351

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
352
    async def generate(
        self,
        prompt: EngineCoreRequest | PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Submit a generation request and stream its outputs.

        Steps:
            * 1) Make sure the background output handler is running.
            * 2) Submit the request via add_request(), which returns the
                 per-request output collector.
            * 3) Yield outputs from the collector until one is finished.

        The background output_handler task pulls outputs from EngineCore
        and pushes them into the per-request collector; this generator
        only drains that collector.

        Raises:
            ValueError: on request validation problems (including prompt
                logprobs requested with kv_sharing_fast_prefill enabled).
            EngineDeadError: if the engine is dead.
            EngineGenerateError: wrapping any other unexpected exception.
        """
        fast_prefill = self.vllm_config.cache_config.kv_sharing_fast_prefill
        if fast_prefill and sampling_params.prompt_logprobs:
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs"
            )

        try:
            # The handler is started lazily here so that __init__ can run
            # before an event loop exists, which lets startup failures be
            # handled gracefully in the OpenAI server.
            self._run_output_handler()

            # Truncation is derived from sampling_params only when the
            # caller did not supply tokenization kwargs.
            if tokenization_kwargs is None:
                tokenization_kwargs = {}
                _validate_truncation_size(
                    self.model_config.max_model_len,
                    sampling_params.truncate_prompt_tokens,
                    tokenization_kwargs,
                )

            collector = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                prompt_text=prompt_text,
            )

            # The output_handler task fills the collector; this loop
            # drains it and yields to the caller.
            done = False
            while not done:
                # Take an already-available item without awaiting when
                # possible (avoids task switching under load).
                output = collector.get_nowait() or await collector.get()

                # OutputProcessor and EngineCore each clean up the request
                # themselves once it is finished.
                done = output.finished
                yield output

        except (asyncio.CancelledError, GeneratorExit):
            # Client disconnected, generate() was cancelled, or the
            # generator was garbage collected: abort the request.
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        except EngineDeadError:
            # Engine is dead. Do not abort since we shut down.
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        except ValueError:
            # Request validation error.
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        except Exception as exc:
            # Unexpected error in the generate() task (possibly recoverable).
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from exc

    def _run_output_handler(self):
        """Start the background task that moves EngineCore outputs into the
        per-request collectors. Does nothing if the task already exists."""

        if self.output_handler is not None:
            return

        # Bind everything the task needs to locals so the coroutine holds no
        # reference back to this AsyncLLM; a circular reference would keep
        # the object from being garbage collected and cleaned up.
        core = self.engine_core
        out_proc = self.output_processor
        stats_enabled = self.log_stats
        stat_manager = self.logger_manager
        input_processor = self.processor

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    batch = await core.get_output_async()
                    total = len(batch.outputs)

                    if stats_enabled and total:
                        iteration_stats = IterationStats()
                    else:
                        iteration_stats = None

                    # Process in chunks of at most
                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE so the event loop is
                    # not blocked for too long.
                    chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
                    if total > chunk_size:
                        chunks = np.array_split(
                            batch.outputs, cdiv(total, chunk_size)
                        )
                    else:
                        chunks = (batch.outputs,)

                    last_chunk = len(chunks) - 1
                    for chunk_idx, chunk in enumerate(chunks):
                        # 2) Process EngineCoreOutputs.
                        processed = out_proc.process_outputs(
                            chunk, batch.timestamp, iteration_stats
                        )
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed.request_outputs

                        # Let other asyncio tasks run between chunks.
                        if chunk_idx < last_chunk:
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        await core.abort_requests_async(processed.reqs_to_abort)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if stat_manager:
                        stat_manager.record(
                            engine_idx=batch.engine_index,
                            scheduler_stats=batch.scheduler_stats,
                            iteration_stats=iteration_stats,
                            mm_cache_stats=input_processor.stat_mm_cache(),
                        )
            except Exception as exc:
                logger.exception("AsyncLLM output_handler failed.")
                out_proc.propagate_error(exc)

        self.output_handler = asyncio.create_task(output_handler())
527

528
    async def abort(self, request_id: str | Iterable[str]) -> None:
        """Abort one or several requests in OutputProcessor and EngineCore.

        A single string is treated as one request id; any other iterable is
        converted to a list of ids.
        """
        if isinstance(request_id, str):
            request_ids = (request_id,)
        else:
            request_ids = as_list(request_id)

        # The output processor returns the ids that are then aborted in
        # the engine core.
        ids_to_abort = self.output_processor.abort_requests(request_ids)
        await self.engine_core.abort_requests_async(ids_to_abort)

        if not self.log_requests:
            return
        logger.info("Aborted request(s) %s.", ",".join(request_ids))
539

540
    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        truncate_prompt_tokens: int | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Submit a pooling request and stream its outputs.

        Steps:
            * 1) Make sure the background output handler is running.
            * 2) Validate the truncation size and submit the request via
                 add_request(), which returns the per-request collector.
            * 3) Yield PoolingRequestOutput items from the collector until
                 one is finished.

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request collector.

        The caller of encode() iterates the returned AsyncGenerator to
        receive each PoolingRequestOutput.

        Raises:
            ValueError: on request validation problems.
            EngineDeadError: if the engine is dead.
            EngineGenerateError: wrapping any other unexpected exception.
        """

        try:
            # We start the output_handler on the first call to encode() or
            # generate() so we can call __init__ before the event loop,
            # which enables us to handle startup failure gracefully in the
            # OpenAI server.
            self._run_output_handler()

            if tokenization_kwargs is None:
                tokenization_kwargs = {}
            _validate_truncation_size(
                self.model_config.max_model_len,
                truncate_prompt_tokens,
                tokenization_kwargs,
            )

            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, encode() is
        # cancelled or the generator is closed / garbage collected. So, we
        # abort the request if we end up here. GeneratorExit is handled the
        # same way as in generate(); it is a BaseException and would
        # otherwise bypass every handler below, leaving the request alive.
        except (asyncio.CancelledError, GeneratorExit):
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e
628

629
    @property
    def tokenizer(self) -> AnyTokenizer | None:
        """Tokenizer held by the processor (may be None)."""
        input_processor = self.processor
        return input_processor.tokenizer
632

633
    @tokenizer.setter
    def tokenizer(self, tokenizer: AnyTokenizer | None) -> None:
        """Store the given tokenizer on the processor."""
        input_processor = self.processor
        input_processor.tokenizer = tokenizer
636

637
    async def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer.

        Raises:
            ValueError: if no tokenizer is set.
        """
        if self.tokenizer is not None:
            return self.tokenizer

        raise ValueError(
            "Unable to get tokenizer because skip_tokenizer_init is True"
        )
644
645

    async def is_tracing_enabled(self) -> bool:
        """Return True when an OTLP traces endpoint is configured."""
        endpoint = self.observability_config.otlp_traces_endpoint
        return endpoint is not None
647

648
    async def do_log_stats(self) -> None:
        """Ask the stat logger manager to log, if one is set."""
        if not self.logger_manager:
            return
        self.logger_manager.log()
651
652
653

    async def check_health(self) -> None:
        """Raise the engine's dead error if the engine is in an errored state."""
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error
656
657

    async def start_profile(self) -> None:
        """Start profiling in the engine core and, when a local profiler is
        configured, start it in a worker thread; both are awaited together."""
        pending = [self.engine_core.profile_async(True)]
        local_profiler = self.profiler
        if local_profiler is not None:
            pending.append(asyncio.to_thread(local_profiler.start))
        await asyncio.gather(*pending)
662
663

    async def stop_profile(self) -> None:
        """Stop profiling in the engine core and, when a local profiler is
        configured, stop it in a worker thread; both are awaited together."""
        pending = [self.engine_core.profile_async(False)]
        local_profiler = self.profiler
        if local_profiler is not None:
            pending.append(asyncio.to_thread(local_profiler.stop))
        await asyncio.gather(*pending)
668

669
    async def reset_mm_cache(self) -> None:
        """Clear the processor's multi-modal cache, then reset the engine
        core's multi-modal cache."""
        input_processor = self.processor
        input_processor.clear_mm_cache()
        core = self.engine_core
        await core.reset_mm_cache_async()

673
    async def reset_prefix_cache(self, device: Device | None = None) -> None:
        """Reset the engine core's prefix cache.

        Raises:
            ValueError: if ``device`` is ``Device.CPU``.
        """
        cpu_requested = device == Device.CPU
        if cpu_requested:
            raise ValueError("Not supported on CPU.")
        await self.engine_core.reset_prefix_cache_async()

678
    async def sleep(self, level: int = 1) -> None:
        """Reset the prefix cache, then put the engine core to sleep at the
        given level."""
        await self.reset_prefix_cache()
        core = self.engine_core
        await core.sleep_async(level)

682
    async def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake the engine core up, forwarding the optional tags."""
        core = self.engine_core
        await core.wake_up_async(tags)
684

685
686
687
    async def is_sleeping(self) -> bool:
        """Return the engine core's sleeping state."""
        sleeping = await self.engine_core.is_sleeping_async()
        return sleeping

688
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Ask the engine core to load a LoRA adapter; return its result."""
        loaded = await self.engine_core.add_lora_async(lora_request)
        return loaded

    async def remove_lora(self, lora_id: int) -> bool:
        """Ask the engine core to remove a LoRA adapter; return its result."""
        removed = await self.engine_core.remove_lora_async(lora_id)
        return removed

696
    async def list_loras(self) -> set[int]:
        """Return the adapter ids reported by the engine core."""
        adapter_ids = await self.engine_core.list_loras_async()
        return adapter_ids

    async def pin_lora(self, lora_id: int) -> bool:
        """Ask the engine core to pin an adapter so it is not evicted."""
        pinned = await self.engine_core.pin_lora_async(lora_id)
        return pinned
703

704
705
706
    async def collective_rpc(
        self,
        method: str,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
    ):
        """
        Forward a collective RPC for ``method`` to the engine core and
        return whatever the engine core returns.
        """
        core = self.engine_core
        result = await core.collective_rpc_async(method, timeout, args, kwargs)
        return result
717

718
719
720
721
722
723
724
725
    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """Wait for all requests to be drained.

        Polls the engine core once per second until it reports that no
        data-parallel engines are running.

        Args:
            drain_timeout: Maximum time to wait, in seconds.

        Raises:
            TimeoutError: if the engines are still running when the timeout
                elapses.
        """
        # A monotonic clock keeps the timeout correct even if the system
        # wall clock is adjusted (NTP, manual change) while waiting.
        deadline = time.monotonic() + drain_timeout
        while time.monotonic() < deadline:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

            logger.info("Engines are still running, waiting for requests to drain...")
            await asyncio.sleep(1)  # Wait 1 second before checking again

        raise TimeoutError(
            f"Timeout reached after {drain_timeout} seconds "
            "waiting for requests to drain."
        )
733

734
735
736
    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ):
        """
        Change the data parallel size by adding or removing engine cores.

        Nothing is done when the size is unchanged. Otherwise in-flight
        requests are drained first, the engine core is rescaled, and the
        stored parallel config is updated.

        Args:
            new_data_parallel_size: The new number of data parallel workers
            drain_timeout:
                Maximum time to wait for requests to drain (seconds)
        """
        parallel_config = self.vllm_config.parallel_config
        previous_size = parallel_config.data_parallel_size
        if new_data_parallel_size == previous_size:
            logger.info(
                "Data parallel size is already %s, skipping scale",
                new_data_parallel_size,
            )
            return

        logger.info(
            "Waiting for requests to drain before scaling up to %s engines...",
            new_data_parallel_size,
        )
        await self.wait_for_requests_to_drain(drain_timeout)
        logger.info(
            "Requests have been drained, proceeding with scale to %s engines",
            new_data_parallel_size,
        )

        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size

        # Stat loggers are recreated only when scaling up with stats enabled.
        scaled_up = new_data_parallel_size > previous_size
        if scaled_up and self.log_stats:
            # TODO(rob): fix this after talking with Ray team.
            # This resets all the prometheus metrics since we
            # unregister during initialization. Need to understand
            # the intended behavior here better.
            # NOTE(review): _run_output_handler binds logger_manager to a
            # local when the task starts - confirm the running task picks
            # up this replacement.
            self.logger_manager = StatLoggerManager(
                vllm_config=self.vllm_config,
                engine_idxs=list(range(new_data_parallel_size)),
                custom_stat_loggers=None,
            )

776
777
    @property
    def is_running(self) -> bool:
        """True while the output handler task is not done.

        Also True before the handler has been started (it is None then).
        """
        handler = self.output_handler
        if handler is None:
            return True
        return not handler.done()
780
781
782

    @property
    def is_stopped(self) -> bool:
        """Same value as ``errored``."""
        stopped = self.errored
        return stopped
784
785
786

    @property
    def errored(self) -> bool:
        """Truthy when the engine core is flagged dead or the output
        handler is no longer running."""
        engine_dead = self.engine_core.resources.engine_dead
        return engine_dead or not self.is_running
788
789
790

    @property
    def dead_error(self) -> BaseException:
        """Return a newly constructed EngineDeadError."""
        error = EngineDeadError()
        return error