async_llm.py 33.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import os
import socket
6
import time
7
import warnings
8
from collections.abc import AsyncGenerator, Iterable, Mapping
9
from copy import copy
10
from typing import Any, cast
11

12
import numpy as np
13
import torch
14

15
import vllm.envs as envs
16
from vllm.config import VllmConfig
17
from vllm.engine.arg_utils import AsyncEngineArgs
18
from vllm.engine.protocol import EngineClient
19
from vllm.entrypoints.utils import _validate_truncation_size
20
from vllm.inputs import PromptType
21
22
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
23
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
24
from vllm.outputs import PoolingRequestOutput, RequestOutput
25
from vllm.plugins.io_processors import get_io_processor
26
from vllm.pooling_params import PoolingParams
27
from vllm.sampling_params import SamplingParams
28
from vllm.tasks import SupportedTask
29
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
30
from vllm.tracing import init_tracer
31
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
32
from vllm.usage.usage_lib import UsageContext
33
34
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
35
from vllm.utils.math_utils import cdiv
36
from vllm.v1.engine import EngineCoreRequest
37
from vllm.v1.engine.core_client import EngineCoreClient
38
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
39
from vllm.v1.engine.input_processor import InputProcessor
40
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
41
from vllm.v1.engine.parallel_sampling import ParentRequest
42
from vllm.v1.executor import Executor
43
44
45
46
47
from vllm.v1.metrics.loggers import (
    StatLoggerFactory,
    StatLoggerManager,
    load_stat_logger_plugin_factories,
)
48
from vllm.v1.metrics.prometheus import shutdown_prometheus
49
from vllm.v1.metrics.stats import IterationStats
50
51
52
53
54
55
56
57

logger = init_logger(__name__)


class AsyncLLM(EngineClient):
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: list[StatLoggerFactory] | None = None,
        aggregate_engine_logging: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats. Stat logging is also turned on
                (without the default loggers) when custom stat loggers are
                passed in or discovered through stat-logger plugins.
            usage_context: Usage context of the LLM.
            mm_registry: Multi-modal registry.
            use_cached_outputs: Whether to use cached outputs.
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop.
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            aggregate_engine_logging: Forwarded to StatLoggerManager.
            client_addresses: Forwarded to the EngineCore client.
            client_count: Forwarded to the EngineCore client and to
                StatLoggerManager.
            client_index: Forwarded to the EngineCore client.

        NOTE(review): usage_context, mm_registry, use_cached_outputs and
        start_engine_loop are accepted but not read by this constructor.

        Returns:
            None
        """
        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.log_requests = log_requests

        # Explicitly passed stat loggers plus any provided by plugins.
        custom_stat_loggers = list(stat_loggers or [])
        custom_stat_loggers.extend(load_stat_logger_plugin_factories())

        has_custom_loggers = bool(custom_stat_loggers)
        self.log_stats = log_stats or has_custom_loggers
        if not log_stats and has_custom_loggers:
            logger.info(
                "AsyncLLM created with log_stats=False, "
                "but custom stat loggers were found; "
                "enabling logging without default stat loggers."
            )

        if self.model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = cached_tokenizer_from_config(self.model_config)

        self.input_processor = InputProcessor(self.vllm_config, tokenizer)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        # NOTE: self.tokenizer is read from self.input_processor, so the
        # InputProcessor must exist before this point.
        self.output_processor = OutputProcessor(
            self.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
        )
        endpoint = self.observability_config.otlp_traces_endpoint
        if endpoint is not None:
            tracer = init_tracer("vllm.llm_engine", endpoint)
            self.output_processor.tracer = tracer

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Loggers.
        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=custom_stat_loggers,
                # Default loggers only when the caller itself asked for stats.
                enable_default_loggers=log_stats,
                client_count=client_count,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        # Pause / resume state for async RL workflows.
        self._pause_cond = asyncio.Condition()
        self._paused = False

        self.output_handler: asyncio.Task | None = None
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running event loop yet; add_request() starts the output
            # handler lazily on the first request.
            pass
        else:
            # Start output handler eagerly if we are in the asyncio eventloop.
            # This call is kept outside the try so a RuntimeError raised while
            # starting the handler is not mistaken for "no running loop".
            self._run_output_handler()

        if (
            vllm_config.profiler_config.profiler == "torch"
            and not vllm_config.profiler_config.ignore_frontend
        ):
            profiler_dir = vllm_config.profiler_config.torch_profiler_dir
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                profiler_dir,
            )
            worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=vllm_config.profiler_config.torch_profiler_with_stack,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    profiler_dir,
                    worker_name=worker_name,
                    use_gzip=vllm_config.profiler_config.torch_profiler_use_gzip,
                ),
            )
        else:
            self.profiler = None

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_log_requests: bool = False,
        aggregate_engine_logging: bool = False,
        disable_log_stats: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> "AsyncLLM":
        """Build an AsyncLLM from an already-constructed VllmConfig.

        The executor class is resolved from the config with
        ``Executor.get_class``. ``enable_log_requests`` maps onto the
        constructor's ``log_requests`` flag, and ``disable_log_stats`` is
        inverted into ``log_stats``. All other arguments are forwarded as-is.
        """
        executor_class = Executor.get_class(vllm_config)
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            log_requests=enable_log_requests,
            log_stats=not disable_log_stats,
            aggregate_engine_logging=aggregate_engine_logging,
            usage_context=usage_context,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs.

        The VllmConfig is derived from ``engine_args`` for the given usage
        context, and the executor class is then resolved from that config.
        """
        config = engine_args.create_engine_config(usage_context)
        executor_cls = Executor.get_class(config)

        return cls(
            vllm_config=config,
            executor_class=executor_cls,
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    def __del__(self):
        """Finalizer: delegate to shutdown() so resources are released."""
        self.shutdown()

    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC.

        Uses getattr() for instance attributes so this is safe to call
        (e.g. from __del__) even if __init__ did not finish.
        """
        shutdown_prometheus()

        core_client = getattr(self, "engine_core", None)
        if core_client:
            core_client.shutdown()

        task = getattr(self, "output_handler", None)
        if task is not None:
            cancel_task_threadsafe(task)

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the engine, as reported by EngineCore."""
        tasks = await self.engine_core.get_supported_tasks_async()
        return tasks

    async def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest | PromptType,
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        prompt_text: str | None = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        ``prompt`` is either a raw prompt (processed by the InputProcessor)
        or a ready-made EngineCoreRequest. ``prompt_text`` may only be given
        together with an EngineCoreRequest. When sampling with ``n > 1``,
        the request is fanned out into ``n`` child requests that share one
        output collector.

        Returns:
            The RequestOutputCollector that receives this request's outputs.

        Raises:
            EngineDeadError: if the engine has already errored.
            ValueError: on invalid request options.
        """
        if self.errored:
            raise EngineDeadError()

        pooling = isinstance(params, PoolingParams)

        fast_prefill = self.vllm_config.cache_config.kv_sharing_fast_prefill
        if fast_prefill and not pooling and params.prompt_logprobs:
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs"
            )

        if tokenization_kwargs is None:
            tokenization_kwargs = {}
        _validate_truncation_size(
            self.model_config.max_model_len,
            params.truncate_prompt_tokens,
            tokenization_kwargs,
        )

        # Convert Input --> Request.
        if not isinstance(prompt, EngineCoreRequest):
            if prompt_text is not None:
                raise ValueError(
                    "should only provide prompt_text with EngineCoreRequest"
                )
            request = self.input_processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
                data_parallel_rank,
            )
            # Recover the prompt text for detokenization / output.
            if isinstance(prompt, str):
                prompt_text = prompt
            elif isinstance(prompt, Mapping):
                prompt_text = cast(str | None, prompt.get("prompt"))
        else:
            request = prompt
            if request.request_id != request_id:
                logger.warning_once(
                    "AsyncLLM.add_request() was passed a request_id parameter that "
                    "does not match the EngineCoreRequest.request_id attribute. The "
                    "latter will be used, and the former will be ignored."
                )

        self.input_processor.assign_request_id(request)

        # The output handler is started lazily here (rather than only in
        # __init__) so the object can be built before an event loop exists,
        # which lets the OpenAI server handle startup failure gracefully.
        self._run_output_handler()

        # Block new requests for as long as generation is paused.
        async with self._pause_cond:
            await self._pause_cond.wait_for(lambda: not self._paused)

        # One collector per (parent) request.
        collector = RequestOutputCollector(params.output_kind, request.request_id)

        # process_inputs() may have cloned/updated the params; use those.
        params = request.params

        if pooling or params.n == 1:
            await self._add_request(request, prompt_text, None, 0, collector)
            return collector

        assert isinstance(params, SamplingParams)

        # Fan out child requests (for n>1). The last child reuses the
        # original request object; the others are shallow copies.
        parent = ParentRequest(request)
        for idx in range(params.n):
            child_id, child_params = parent.get_child_info(idx)
            child = request if idx == params.n - 1 else copy(request)
            child.request_id = child_id
            child.sampling_params = child_params
            await self._add_request(child, prompt_text, parent, idx, collector)
        return collector

    async def _add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        index: int,
        queue: RequestOutputCollector,
    ):
        """Register one (possibly child) request locally, then send it to EngineCore.

        The request is added to the OutputProcessor in this process first,
        then to EngineCore (separate process).
        """
        output_processor = self.output_processor
        output_processor.add_request(request, prompt, parent_req, index, queue)

        await self.engine_core.add_request_async(request)

        if self.log_requests:
            logger.info("Added request %s.", request.request_id)

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: EngineCoreRequest | PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Submit one generation request and stream back its RequestOutputs.

        The request is handed to add_request(), which processes the input,
        registers it with the OutputProcessor and forwards it to the
        EngineCore (separate process). A background output_handler task
        pushes outputs into the request's RequestOutputCollector; this
        generator drains that collector and yields every RequestOutput up to
        and including the first one marked finished.

        Raises:
            EngineDeadError: re-raised as-is; the request is not aborted.
            ValueError: re-raised as-is (request validation error).
            EngineGenerateError: for any other exception (chained as the
                cause), after aborting the request if it was added.
        """
        collector: RequestOutputCollector | None = None
        try:
            collector = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                prompt_text=prompt_text,
            )

            while True:
                # Take an already-queued output without awaiting when one is
                # available; this avoids a task switch under load.
                output = collector.get_nowait() or await collector.get()

                # OutputProcessor and EngineCore each clean up finished
                # requests on their own, so nothing else is needed here.
                is_last = output.finished
                assert isinstance(output, RequestOutput)
                yield output
                if is_last:
                    break

        # A client disconnect cancels generate() or closes / garbage-collects
        # the generator; abort the request in that case.
        except (asyncio.CancelledError, GeneratorExit):
            if collector is not None:
                await self.abort(collector.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            if collector is not None:
                await self.abort(collector.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    def _run_output_handler(self):
        """Start (at most once) the background task that pulls outputs from
        EngineCore and pushes them into the per-request output collectors.

        Calling this again after the task has been created is a no-op.
        """
        if self.output_handler is not None:
            return

        # Bind everything the task needs to locals so the task holds no
        # reference back to this AsyncLLM; otherwise the AsyncLLM could not be
        # garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        logger_manager = self.logger_manager
        input_processor = self.input_processor

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    num_outputs = len(outputs.outputs)

                    iteration_stats = None
                    if log_stats and num_outputs:
                        iteration_stats = IterationStats()

                    # Process in chunks of at most VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
                    # outputs so a large batch does not block the event loop.
                    if num_outputs > envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
                        slices = np.array_split(
                            outputs.outputs,
                            cdiv(num_outputs, envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE),
                        )
                    else:
                        slices = (outputs.outputs,)

                    last_chunk = len(slices) - 1
                    for chunk_idx, chunk in enumerate(slices):
                        # 2) Process EngineCoreOutputs.
                        processed = output_processor.process_outputs(
                            chunk, outputs.timestamp, iteration_stats
                        )
                        # RequestOutputs have already been pushed to their
                        # collectors, so none should be returned here.
                        assert not processed.request_outputs

                        # Yield to other asyncio tasks between chunks.
                        if chunk_idx < last_chunk:
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        await engine_core.abort_requests_async(
                            processed.reqs_to_abort
                        )

                    output_processor.update_scheduler_stats(outputs.scheduler_stats)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if logger_manager:
                        logger_manager.record(
                            engine_idx=outputs.engine_index,
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                            mm_cache_stats=input_processor.stat_mm_cache(),
                        )
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())

    async def abort(
        self, request_id: str | Iterable[str], internal: bool = False
    ) -> None:
        """Abort RequestId in OutputProcessor and EngineCore.

        Accepts a single request id or an iterable of ids. The OutputProcessor
        returns the ids to abort in EngineCore, which are then forwarded.
        """
        if isinstance(request_id, str):
            ids = (request_id,)
        else:
            ids = as_list(request_id)

        core_ids = self.output_processor.abort_requests(ids, internal)
        await self.engine_core.abort_requests_async(core_ids)

        if self.log_requests:
            logger.info("Aborted request(s) %s.", ",".join(ids))

    async def pause_generation(
        self,
        *,
        wait_for_inflight_requests: bool = False,
        clear_cache: bool = True,
    ) -> None:
        """
        Pause generation to allow model weight updates.

        New generation/encoding requests are blocked until resume.
        Calling this while already paused returns immediately.

        Args:
            wait_for_inflight_requests: When ``True`` waits for in-flight
                requests to finish before pausing. When ``False`` (default),
                immediately aborts any in-flight requests.
            clear_cache: Whether to clear KV cache and prefix cache after
                draining. Set to ``False`` to preserve cache for faster resume.
                Default is ``True`` (clear caches).
        """
        async with self._pause_cond:
            if self._paused:
                return
            self._paused = True

        if not wait_for_inflight_requests:
            inflight_ids = list(self.output_processor.request_states)
            if inflight_ids:
                await self.abort(inflight_ids, internal=True)

        # Let running requests drain before touching any cache.
        if self.output_processor.has_unfinished_requests():
            await self.output_processor.wait_for_requests_to_drain()

        if not clear_cache:
            return
        await self.reset_prefix_cache()
        await self.reset_mm_cache()

    async def resume_generation(self) -> None:
        """Resume generation after :meth:`pause_generation`.

        Clears the paused flag and wakes every request blocked on it.
        """
        cond = self._pause_cond
        async with cond:
            self._paused = False
            cond.notify_all()

    async def is_paused(self) -> bool:
        """Report whether generation is currently paused (read under the pause lock)."""
        async with self._pause_cond:
            paused = self._paused
        return paused

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        truncate_prompt_tokens: int | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Main function called by the API server to kick off a pooling request
            * 1) Making an AsyncStream corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request AsyncStream.

        The caller of encode() iterates the returned AsyncGenerator,
        receiving each PoolingRequestOutput.

        If the generator is cancelled, or closed / garbage-collected before
        the request finishes, the request is aborted.

        NOTE: truncate_prompt_tokens is deprecated in v0.14. While it is
        still accepted, it is applied only when
        ``pooling_params.truncate_prompt_tokens`` is not set; the caller's
        PoolingParams object is not mutated.
        TODO: Remove truncate_prompt_tokens in v0.15.
        """
        q: RequestOutputCollector | None = None
        try:
            if truncate_prompt_tokens is not None:
                warnings.warn(
                    "The `truncate_prompt_tokens` parameter in `AsyncLLM.encode()` "
                    "is deprecated and will be removed in v0.15. "
                    "Please use `pooling_params.truncate_prompt_tokens` instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )
                # Keep honoring the deprecated argument until it is removed.
                # An explicit pooling_params value wins; copy so the caller's
                # object is left untouched.
                if pooling_params.truncate_prompt_tokens is None:
                    pooling_params = copy(pooling_params)
                    pooling_params.truncate_prompt_tokens = truncate_prompt_tokens

            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, encode() is
        # cancelled or the generator is closed / garbage collected. So, we
        # abort the request if we end up here (matches generate()).
        except (asyncio.CancelledError, GeneratorExit):
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """The InputProcessor's tokenizer (None when tokenizer init is skipped)."""
        processor = self.input_processor
        return processor.tokenizer

    async def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer.

        Raises:
            ValueError: if no tokenizer is available (skip_tokenizer_init).
        """
        tok = self.tokenizer
        if tok is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )
        return tok

    async def is_tracing_enabled(self) -> bool:
        """True when an OTLP traces endpoint is configured."""
        return self.observability_config.otlp_traces_endpoint is not None  # type: ignore

    async def do_log_stats(self) -> None:
        """Ask the stat logger manager (if stats are enabled) to log now."""
        manager = self.logger_manager
        if manager:
            manager.log()

    async def check_health(self) -> None:
        """Raise ``self.dead_error`` if the engine has errored; otherwise return."""
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error

    async def start_profile(self) -> None:
        """Start profiling in EngineCore and, if configured, the local CPU profiler.

        The local torch profiler is started in a worker thread, concurrently
        with the EngineCore call.
        """
        pending = [self.engine_core.profile_async(True)]
        local_profiler = self.profiler
        if local_profiler is not None:
            pending.append(asyncio.to_thread(local_profiler.start))
        await asyncio.gather(*pending)

    async def stop_profile(self) -> None:
        """Stop profiling in EngineCore and, if configured, the local CPU profiler.

        The local torch profiler is stopped in a worker thread, concurrently
        with the EngineCore call.
        """
        pending = [self.engine_core.profile_async(False)]
        local_profiler = self.profiler
        if local_profiler is not None:
            pending.append(asyncio.to_thread(local_profiler.stop))
        await asyncio.gather(*pending)

    async def reset_mm_cache(self) -> None:
        """Clear the multi-modal cache locally, then in EngineCore."""
        self.input_processor.clear_mm_cache()
        core = self.engine_core
        await core.reset_mm_cache_async()

    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        return await self.engine_core.reset_prefix_cache_async(
            reset_running_requests, reset_connector
        )
751

752
    async def sleep(self, level: int = 1) -> None:
753
        await self.reset_prefix_cache()
754
755
        await self.engine_core.sleep_async(level)

756
757
758
        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(1, level)

759
    async def wake_up(self, tags: list[str] | None = None) -> None:
760
        await self.engine_core.wake_up_async(tags)
761

762
763
764
        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

765
766
767
    async def is_sleeping(self) -> bool:
        return await self.engine_core.is_sleeping_async()

768
    async def add_lora(self, lora_request: LoRARequest) -> bool:
769
        """Load a new LoRA adapter into the engine for future requests."""
770
771
772
773
774
775
        return await self.engine_core.add_lora_async(lora_request)

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return await self.engine_core.remove_lora_async(lora_id)

776
    async def list_loras(self) -> set[int]:
777
778
779
780
781
782
        """List all registered adapters."""
        return await self.engine_core.list_loras_async()

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return await self.engine_core.pin_lora_async(lora_id)
783

784
785
786
    async def collective_rpc(
        self,
        method: str,
787
        timeout: float | None = None,
788
        args: tuple = (),
789
        kwargs: dict | None = None,
790
    ):
791
792
793
794
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
795
796
            method, timeout, args, kwargs
        )
797

798
799
800
801
802
803
804
805
    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """Wait for all requests to be drained."""
        start_time = time.time()
        while time.time() - start_time < drain_timeout:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

806
            logger.info("Engines are still running, waiting for requests to drain...")
807
808
            await asyncio.sleep(1)  # Wait 1 second before checking again

809
810
811
812
        raise TimeoutError(
            f"Timeout reached after {drain_timeout} seconds "
            "waiting for requests to drain."
        )
813

814
815
816
    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ):
817
818
819
820
821
822
823
824
        """
        Scale up or down the data parallel size by adding or removing
        engine cores.
        Args:
            new_data_parallel_size: The new number of data parallel workers
            drain_timeout:
                Maximum time to wait for requests to drain (seconds)
        """
825
        old_data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
826
        if old_data_parallel_size == new_data_parallel_size:
827
828
829
830
            logger.info(
                "Data parallel size is already %s, skipping scale",
                new_data_parallel_size,
            )
831
832
            return
        logger.info(
833
834
835
            "Waiting for requests to drain before scaling up to %s engines...",
            new_data_parallel_size,
        )
836
837
        await self.wait_for_requests_to_drain(drain_timeout)
        logger.info(
838
839
840
            "Requests have been drained, proceeding with scale to %s engines",
            new_data_parallel_size,
        )
841
        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
842
        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
843
844

        # recreate stat loggers
845
846
847
848
849
850
        if new_data_parallel_size > old_data_parallel_size and self.log_stats:
            # TODO(rob): fix this after talking with Ray team.
            # This resets all the prometheus metrics since we
            # unregister during initialization. Need to understand
            # the intended behavior here better.
            self.logger_manager = StatLoggerManager(
851
                vllm_config=self.vllm_config,
852
                engine_idxs=list(range(new_data_parallel_size)),
853
854
855
                custom_stat_loggers=None,
            )

856
857
    @property
    def is_running(self) -> bool:
858
859
        # Is None before the loop is started.
        return self.output_handler is None or not self.output_handler.done()
860
861
862

    @property
    def is_stopped(self) -> bool:
863
        return self.errored
864
865
866

    @property
    def errored(self) -> bool:
867
        return self.engine_core.resources.engine_dead or not self.is_running
868
869
870

    @property
    def dead_error(self) -> BaseException:
871
        return EngineDeadError()