async_llm.py 30.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import os
import socket
6
import time
7
from collections.abc import AsyncGenerator, Iterable, Mapping
8
from copy import copy
9
from typing import Any, Optional, Union
10

11
import numpy as np
12
import torch
13

14
import vllm.envs as envs
15
from vllm.config import VllmConfig
16
17
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
18
from vllm.entrypoints.utils import _validate_truncation_size
19
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
20
from vllm.inputs import PromptType
21
22
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
23
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
24
from vllm.outputs import PoolingRequestOutput, RequestOutput
25
from vllm.plugins.io_processors import get_io_processor
26
from vllm.pooling_params import PoolingParams
27
from vllm.sampling_params import SamplingParams
28
from vllm.tasks import SupportedTask
29
from vllm.tracing import init_tracer
30
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
31
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
32
from vllm.usage.usage_lib import UsageContext
33
from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv, deprecate_kwargs
34
from vllm.v1.engine import EngineCoreRequest
35
from vllm.v1.engine.core_client import EngineCoreClient
36
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
37
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
38
from vllm.v1.engine.parallel_sampling import ParentRequest
39
from vllm.v1.engine.processor import Processor
40
from vllm.v1.executor.abstract import Executor
41
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
42
from vllm.v1.metrics.prometheus import shutdown_prometheus
43
from vllm.v1.metrics.stats import IterationStats
44
45
46
47
48
49
50
51

# Module-level logger for this file, created with vLLM's init_logger.
logger = init_logger(__name__)


class AsyncLLM(EngineClient):
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        client_addresses: Optional[dict[str, str]] = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats. Stats collection is also
                enabled when ``stat_loggers`` is given, but the default
                stat loggers are only installed when this is True.
            usage_context: Usage context of the LLM.
            mm_registry: Multi-modal registry.
            use_cached_outputs: Whether to use cached outputs.
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop.
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            client_addresses: forwarded as-is to
                ``EngineCoreClient.make_async_mp_client``.
            client_count: forwarded to the engine-core client and to the
                StatLoggerManager.
            client_index: forwarded to the engine-core client.

        NOTE(review): ``usage_context``, ``mm_registry``,
        ``use_cached_outputs`` and ``start_engine_loop`` are accepted but
        not referenced by this constructor. The output handler is started
        eagerly only when a running event loop exists.

        Raises:
            ValueError: if ``envs.VLLM_USE_V1`` is False.

        Returns:
            None
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github."
            )

        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.log_requests = log_requests

        # Custom stat loggers imply stats collection even if log_stats=False.
        self.log_stats = log_stats or (stat_loggers is not None)
        if not log_stats and stat_loggers is not None:
            logger.info(
                "AsyncLLM created with log_stats=False and non-empty custom "
                "logger list; enabling logging without default stat loggers"
            )

        # Processor (converts inputs --> EngineCoreRequests). It holds the
        # tokenizer, which is not created when skip_tokenizer_init is set.
        if self.model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = init_tokenizer_from_configs(self.model_config)

        self.processor = Processor(self.vllm_config, tokenizer)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(
            self.tokenizer, log_stats=self.log_stats
        )
        if self.observability_config.otlp_traces_endpoint is not None:
            tracer = init_tracer(
                "vllm.llm_engine", self.observability_config.otlp_traces_endpoint
            )
            self.output_processor.tracer = tracer

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Loggers.
        self.logger_manager: Optional[StatLoggerManager] = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=stat_loggers,
                enable_default_loggers=log_stats,
                client_count=client_count,
            )
            self.logger_manager.log_engine_initialized()

        self.output_handler: Optional[asyncio.Task] = None
        # Start the output handler eagerly if we are already inside a running
        # event loop; otherwise generate()/encode() start it on first use.
        # Only the loop probe belongs in the try: a RuntimeError raised while
        # starting the handler itself must not be silently swallowed.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass
        else:
            self._run_output_handler()

        if envs.VLLM_TORCH_PROFILER_DIR:
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                envs.VLLM_TORCH_PROFILER_DIR,
            )
            worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    envs.VLLM_TORCH_PROFILER_DIR, worker_name=worker_name, use_gzip=True
                ),
            )
        else:
            self.profiler = None

    @classmethod
    @deprecate_kwargs(
        "disable_log_requests",
        additional_message=(
            "This argument will have no effect. Use `enable_log_requests` instead."
        ),
    )
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        enable_log_requests: bool = False,
        disable_log_stats: bool = False,
        client_addresses: Optional[dict[str, str]] = None,
        client_count: int = 1,
        client_index: int = 0,
        disable_log_requests: bool = True,  # Deprecated, will be removed
    ) -> "AsyncLLM":
        """Build an AsyncLLM from an already-constructed VllmConfig.

        The executor class is resolved from the config with
        ``Executor.get_class``. ``disable_log_requests`` is deprecated and
        has no effect; use ``enable_log_requests`` instead.

        Raises:
            ValueError: if ``envs.VLLM_USE_V1`` is False.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github."
            )

        resolved_executor = Executor.get_class(vllm_config)
        stats_enabled = not disable_log_stats

        return cls(
            vllm_config=vllm_config,
            executor_class=resolved_executor,
            log_stats=stats_enabled,
            usage_context=usage_context,
            log_requests=enable_log_requests,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from AsyncEngineArgs.

        The VllmConfig is built from ``engine_args`` for ``usage_context``.
        The executor class is then resolved from that config. Request and
        stats logging follow ``engine_args.enable_log_requests`` and
        ``engine_args.disable_log_stats``.
        """
        config = engine_args.create_engine_config(usage_context)

        return cls(
            vllm_config=config,
            executor_class=Executor.get_class(config),
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    def __del__(self):
        # Best-effort cleanup on garbage collection. shutdown() looks up its
        # attributes with getattr, so a partially built instance is fine.
        self.shutdown()

    def shutdown(self):
        """Shut down: the background engine-core process/IPC and output task.

        Calls ``shutdown_prometheus()``, shuts down the engine-core client
        if one exists, and cancels the output-handler task. Attributes are
        read with getattr so this is safe on a partially constructed
        instance.
        """
        shutdown_prometheus()

        core_client = getattr(self, "engine_core", None)
        if core_client:
            core_client.shutdown()

        handler_task = getattr(self, "output_handler", None)
        cancel_task_threadsafe(handler_task)

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Ask the engine core which tasks the loaded model supports."""
        tasks = await self.engine_core.get_supported_tasks_async()
        return tasks

    async def add_request(
        self,
        request_id: str,
        prompt: Union[EngineCoreRequest, PromptType],
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
        prompt_text: Optional[str] = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        ``prompt`` may already be an EngineCoreRequest. Otherwise it is run
        through the (deprecated-here) Processor, and ``prompt_text`` must
        then be None; it is derived from the prompt.

        For sampling requests with ``n > 1``, one child request per sample
        is registered. All children share the same output collector.

        Returns:
            The RequestOutputCollector that receives this request's outputs.

        Raises:
            EngineDeadError: if the engine has already errored.
        """
        if self.errored:
            raise EngineDeadError()

        # A single collector serves the user request (and all its children).
        collector = RequestOutputCollector(output_kind=params.output_kind)

        if not isinstance(prompt, EngineCoreRequest):
            assert prompt_text is None
            logger.warning_once(
                "Processor has been moved under OpenAIServing and will "
                "be removed from AsyncLLM in v0.13."
            )
            request = self.processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
                data_parallel_rank,
            )
            if isinstance(prompt, str):
                prompt_text = prompt
            else:
                prompt_text = prompt.get("prompt")
        else:
            request = prompt

        # Pooling requests and single-sample generation need no fan-out.
        if isinstance(params, PoolingParams) or params.n == 1:
            await self._add_request(request, prompt_text, None, 0, collector)
            return collector

        # process_inputs may have cloned/updated the SamplingParams, so read
        # them back from the request instead of using `params`.
        parent_params = request.sampling_params
        assert parent_params is not None

        parent_request = ParentRequest(request_id, parent_params)
        for idx in range(parent_params.n):
            child_id, child_params = parent_request.get_child_info(idx)
            # The final child reuses the original request object; the
            # others get a shallow copy.
            is_final_child = idx == parent_params.n - 1
            child_request = request if is_final_child else copy(request)
            child_request.request_id = child_id
            child_request.sampling_params = child_params
            await self._add_request(
                child_request, prompt_text, parent_request, idx, collector
            )
        return collector

    async def _add_request(
        self,
        request: EngineCoreRequest,
        prompt: Optional[str],
        parent_req: Optional[ParentRequest],
        index: int,
        queue: RequestOutputCollector,
    ):
        """Register one EngineCoreRequest locally, then with EngineCore.

        The OutputProcessor (this process) learns about the request before
        it is sent to EngineCore (separate process).
        """
        self.output_processor.add_request(request, prompt, parent_req, index, queue)
        await self.engine_core.add_request_async(request)

        if not self.log_requests:
            return
        logger.info("Added request %s.", request.request_id)

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: Union[EngineCoreRequest, PromptType],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: Optional[str] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Submit a generation request and stream back its RequestOutputs.

        This is the entry point the API server uses for each request. It:
          1) makes sure the background output handler task is running,
          2) when no ``tokenization_kwargs`` are given, builds them from
             ``sampling_params.truncate_prompt_tokens`` via
             ``_validate_truncation_size``,
          3) registers the request through ``add_request()``, which returns
             the per-request RequestOutputCollector, and
          4) yields items from that collector until one is ``finished``.

        The output handler task pulls outputs from EngineCore and pushes
        them into the collector; this generator only drains it.

        Raises:
            ValueError: for rejected requests, e.g. prompt logprobs while
                ``kv_sharing_fast_prefill`` is enabled.
            EngineDeadError: if the engine has died.
            EngineGenerateError: wrapping any other unexpected exception.

        If the generator is cancelled or closed early, the request is
        aborted before the exception is re-raised.
        """
        fast_prefill = self.vllm_config.cache_config.kv_sharing_fast_prefill
        if fast_prefill and sampling_params.prompt_logprobs:
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs"
            )

        try:
            # Starting the handler lazily here means __init__ can run before
            # an event loop exists, which lets the OpenAI server handle
            # startup failures gracefully.
            self._run_output_handler()

            if tokenization_kwargs is None:
                # Caller-supplied kwargs are used as-is; truncation settings
                # are derived from sampling_params only when none were given.
                tokenization_kwargs = {}
                _validate_truncation_size(
                    self.model_config.max_model_len,
                    sampling_params.truncate_prompt_tokens,
                    tokenization_kwargs,
                )

            collector = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                prompt_text=prompt_text,
            )

            # The output handler fills the collector; this loop drains it.
            while True:
                # Take a ready item without awaiting when possible; avoiding
                # task switches helps performance under load.
                out = collector.get_nowait() or await collector.get()
                # OutputProcessor and EngineCore each clean up a request on
                # their own once it is finished.
                is_last = out.finished
                yield out
                if is_last:
                    break

        # A client disconnect cancels generate() or closes/garbage-collects
        # the generator; either way the request must be aborted.
        except (asyncio.CancelledError, GeneratorExit):
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    def _run_output_handler(self):
        """Start the background task that routes EngineCore outputs.

        The task pulls EngineCoreOutputs from the engine core and runs them
        through the OutputProcessor, which pushes results into per-request
        collectors. It also aborts requests that finished on stop strings
        and records stats.

        Calling this again is a no-op once the task exists. It creates the
        task with ``asyncio.create_task``, so it needs a running event loop.
        """
        if self.output_handler is not None:
            return

        # Bind everything the task needs to locals so the coroutine holds no
        # reference back to `self`; a cycle would keep the AsyncLLM from
        # being garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        logger_manager = self.logger_manager
        processor = self.processor

        async def output_handler():
            try:
                while True:
                    # 1) Pull the next batch of EngineCoreOutputs.
                    outputs = await engine_core.get_output_async()
                    core_outputs = outputs.outputs
                    num_outputs = len(core_outputs)

                    iteration_stats = None
                    if log_stats and num_outputs:
                        iteration_stats = IterationStats()

                    # Work on chunks of at most VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
                    # outputs so a large batch does not block the event loop
                    # for too long.
                    if num_outputs > VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
                        slices = np.array_split(
                            core_outputs,
                            cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE),
                        )
                    else:
                        slices = (core_outputs,)

                    for slice_idx, outputs_slice in enumerate(slices):
                        # 2) Process EngineCoreOutputs. RequestOutputs go
                        # straight into their collectors, not the return value.
                        processed_outputs = output_processor.process_outputs(
                            outputs_slice, outputs.timestamp, iteration_stats
                        )
                        assert not processed_outputs.request_outputs

                        # Let other asyncio tasks run between chunks.
                        has_more_slices = slice_idx + 1 < len(slices)
                        if has_more_slices:
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        await engine_core.abort_requests_async(
                            processed_outputs.reqs_to_abort
                        )

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if logger_manager:
                        logger_manager.record(
                            engine_idx=outputs.engine_index,
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                            mm_cache_stats=processor.stat_mm_cache(),
                        )
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())

    async def abort(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort one or more requests in the OutputProcessor and EngineCore.

        Args:
            request_id: a single request id, or an iterable of ids.
        """
        if isinstance(request_id, str):
            request_ids = (request_id,)
        else:
            request_ids = as_list(request_id)

        # The OutputProcessor returns the ids that EngineCore should abort
        # (presumably including n>1 child ids — TODO confirm).
        ids_for_core = self.output_processor.abort_requests(request_ids)
        await self.engine_core.abort_requests_async(ids_for_core)

        if self.log_requests:
            logger.info("Aborted request(s) %s.", ",".join(request_ids))

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        truncate_prompt_tokens: Optional[int] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Main function called by the API server to kick off a pooling request
            * 1) Making an output collector corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request collector.

        The caller of encode() iterates the returned AsyncGenerator,
        receiving PoolingRequestOutputs until one is finished.

        ``truncate_prompt_tokens`` is always passed to
        ``_validate_truncation_size`` together with ``tokenization_kwargs``
        (a new dict when None is given).

        Raises:
            EngineDeadError: if the engine has died.
            ValueError: for rejected requests.
            EngineGenerateError: wrapping any other unexpected exception.

        If the generator is cancelled, or closed/garbage-collected before
        the final output, the request is aborted and the exception re-raised.
        """

        try:
            # We start the output_handler on the first call to encode() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()

            if tokenization_kwargs is None:
                tokenization_kwargs = {}
            _validate_truncation_size(
                self.model_config.max_model_len,
                truncate_prompt_tokens,
                tokenization_kwargs,
            )

            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the client disconnects, encode() is cancelled or the generator
        # is closed / garbage collected early. GeneratorExit derives from
        # BaseException, so it must be listed explicitly (as generate() does);
        # otherwise no handler runs and the request is never aborted.
        except (asyncio.CancelledError, GeneratorExit):
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    @property
    def tokenizer(self) -> Optional[AnyTokenizer]:
        """The tokenizer held by the input Processor (may be None)."""
        processor = self.processor
        return processor.tokenizer

    @tokenizer.setter
    def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None:
        # The Processor owns the tokenizer; AsyncLLM just forwards to it.
        processor = self.processor
        processor.tokenizer = tokenizer

    async def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer.

        Raises:
            ValueError: if no tokenizer is available.
        """
        current = self.tokenizer
        if current is not None:
            return current
        raise ValueError(
            "Unable to get tokenizer because skip_tokenizer_init is True"
        )

    async def is_tracing_enabled(self) -> bool:
        """True when an OTLP traces endpoint is configured."""
        endpoint = self.observability_config.otlp_traces_endpoint
        return endpoint is not None

    async def do_log_stats(self) -> None:
        """Emit stats through the StatLoggerManager, if one was created."""
        manager = self.logger_manager
        if not manager:
            return
        manager.log()

    async def check_health(self) -> None:
        """Raise ``self.dead_error`` if the engine is in an errored state."""
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error

    async def start_profile(self) -> None:
        """Start profiling in the engine core and, if set, the local profiler.

        The local torch profiler's ``start`` runs via ``asyncio.to_thread``
        and is awaited together with the engine-core request.
        """
        pending = [self.engine_core.profile_async(True)]
        local_profiler = self.profiler
        if local_profiler is not None:
            pending.append(asyncio.to_thread(local_profiler.start))
        await asyncio.gather(*pending)

    async def stop_profile(self) -> None:
        """Stop profiling in the engine core and, if set, the local profiler.

        The local torch profiler's ``stop`` runs via ``asyncio.to_thread``
        and is awaited together with the engine-core request.
        """
        pending = [self.engine_core.profile_async(False)]
        local_profiler = self.profiler
        if local_profiler is not None:
            pending.append(asyncio.to_thread(local_profiler.stop))
        await asyncio.gather(*pending)

    async def reset_mm_cache(self) -> None:
        """Clear multi-modal caches: the local Processor's, then EngineCore's."""
        processor = self.processor
        processor.clear_mm_cache()
        await self.engine_core.reset_mm_cache_async()

    async def reset_prefix_cache(self, device: Optional[Device] = None) -> None:
        """Reset the engine core's prefix cache.

        Raises:
            ValueError: if ``device`` is ``Device.CPU`` (not supported).
        """
        if device == Device.CPU:
            raise ValueError("Not supported on CPU.")
        await self.engine_core.reset_prefix_cache_async()

    async def sleep(self, level: int = 1) -> None:
        """Reset the prefix cache, then put the engine core to sleep."""
        await self.reset_prefix_cache()
        core = self.engine_core
        await core.sleep_async(level)

    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake the engine core; ``tags`` is forwarded unchanged."""
        core = self.engine_core
        await core.wake_up_async(tags)

    async def is_sleeping(self) -> bool:
        """Ask the engine core whether it is currently sleeping."""
        sleeping = await self.engine_core.is_sleeping_async()
        return sleeping

    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        loaded = await self.engine_core.add_lora_async(lora_request)
        return loaded

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        removed = await self.engine_core.remove_lora_async(lora_id)
        return removed

    async def list_loras(self) -> set[int]:
        """List all registered adapters."""
        registered = await self.engine_core.list_loras_async()
        return registered

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        pinned = await self.engine_core.pin_lora_async(lora_id)
        return pinned

    async def collective_rpc(
        self,
        method: str,
        timeout: Optional[float] = None,
        args: tuple = (),
        kwargs: Optional[dict] = None,
    ):
        """Perform a collective RPC call through the engine core.

        All arguments are forwarded unchanged to
        ``engine_core.collective_rpc_async``, whose result is returned.
        """
        rpc_result = await self.engine_core.collective_rpc_async(
            method, timeout, args, kwargs
        )
        return rpc_result

    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """Wait until no data-parallel engines are running.

        Polls ``engine_core.dp_engines_running()`` about once per second.

        Args:
            drain_timeout: Maximum time to wait, in seconds.

        Raises:
            TimeoutError: if engines are still running after
                ``drain_timeout`` seconds.
        """
        # Use a monotonic clock: wall-clock adjustments (NTP, manual changes)
        # must not shorten or extend the drain timeout.
        deadline = time.monotonic() + drain_timeout
        while time.monotonic() < deadline:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

            logger.info("Engines are still running, waiting for requests to drain...")
            await asyncio.sleep(1)  # Wait 1 second before checking again

        raise TimeoutError(
            f"Timeout reached after {drain_timeout} seconds "
            "waiting for requests to drain."
        )

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ):
        """
        Scale up or down the data parallel size by adding or removing
        engine cores.

        The method proceeds in order:
          1) wait for in-flight requests to drain,
          2) ask the engine-core client to rescale,
          3) store the new size in ``vllm_config.parallel_config``.
        When scaling up with stats enabled, the StatLoggerManager is
        recreated for the new engine indices.

        Args:
            new_data_parallel_size: The new number of data parallel workers
            drain_timeout:
                Maximum time to wait for requests to drain (seconds)

        Raises:
            TimeoutError: propagated from ``wait_for_requests_to_drain`` if
                requests do not drain within ``drain_timeout``.
        """
        old_data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if old_data_parallel_size == new_data_parallel_size:
            logger.info(
                "Data parallel size is already %s, skipping scale",
                new_data_parallel_size,
            )
            return
        # The target may be smaller or larger than the current size, so the
        # message must not claim a direction.
        logger.info(
            "Waiting for requests to drain before scaling to %s engines...",
            new_data_parallel_size,
        )
        await self.wait_for_requests_to_drain(drain_timeout)
        logger.info(
            "Requests have been drained, proceeding with scale to %s engines",
            new_data_parallel_size,
        )
        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size

        # recreate stat loggers
        if new_data_parallel_size > old_data_parallel_size and self.log_stats:
            # TODO(rob): fix this after talking with Ray team.
            # This resets all the prometheus metrics since we
            # unregister during initialization. Need to understand
            # the intended behavior here better.
            # NOTE(review): the running output-handler task bound the previous
            # logger_manager to a local when it started, so it keeps recording
            # into that instance; only do_log_stats() sees this new one.
            # Custom stat loggers and client_count from construction are also
            # not carried over. Confirm this is intended.
            self.logger_manager = StatLoggerManager(
                vllm_config=self.vllm_config,
                engine_idxs=list(range(new_data_parallel_size)),
                custom_stat_loggers=None,
            )

    @property
    def is_running(self) -> bool:
        """False only once the output-handler task has completed.

        Before the handler is started (``output_handler`` is None) this
        reports True.
        """
        handler = self.output_handler
        if handler is None:
            return True
        return not handler.done()

    @property
    def is_stopped(self) -> bool:
        """Same value as ``errored``."""
        stopped = self.errored
        return stopped

    @property
    def errored(self) -> bool:
        """True if the engine core is marked dead or the output handler stopped."""
        core_dead = self.engine_core.resources.engine_dead
        return core_dead or not self.is_running

    @property
    def dead_error(self) -> BaseException:
        """A new EngineDeadError instance, created on every access."""
        error = EngineDeadError()
        return error