async_llm.py 33.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import os
import socket
6
import time
7
import warnings
8
from collections.abc import AsyncGenerator, Iterable, Mapping
9
from copy import copy
10
from typing import Any, cast
11

12
import torch
13

14
import vllm.envs as envs
15
from vllm.config import VllmConfig
16
from vllm.engine.arg_utils import AsyncEngineArgs
17
from vllm.engine.protocol import EngineClient
18
from vllm.entrypoints.utils import _validate_truncation_size
19
from vllm.inputs import PromptType
20
21
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
22
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
23
from vllm.outputs import PoolingRequestOutput, RequestOutput
24
from vllm.plugins.io_processors import get_io_processor
25
from vllm.pooling_params import PoolingParams
26
from vllm.renderers import RendererLike
27
from vllm.sampling_params import SamplingParams
28
from vllm.tasks import SupportedTask
29
from vllm.tokenizers import TokenizerLike
30
from vllm.tracing import init_tracer
31
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
32
from vllm.usage.usage_lib import UsageContext
33
34
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
35
from vllm.v1.engine import EngineCoreRequest
36
from vllm.v1.engine.core_client import EngineCoreClient
37
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
38
from vllm.v1.engine.input_processor import InputProcessor
39
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
40
from vllm.v1.engine.parallel_sampling import ParentRequest
41
from vllm.v1.executor import Executor
42
43
44
45
46
from vllm.v1.metrics.loggers import (
    StatLoggerFactory,
    StatLoggerManager,
    load_stat_logger_plugin_factories,
)
47
from vllm.v1.metrics.prometheus import shutdown_prometheus
48
from vllm.v1.metrics.stats import IterationStats
49
50
51
52
53
54
55
56

# Module-level logger used by AsyncLLM for request and lifecycle messages.
logger = init_logger(__name__)


class AsyncLLM(EngineClient):
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: list[StatLoggerFactory] | None = None,
        aggregate_engine_logging: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: Global configuration.
            executor_class: An Executor implementation, e.g. MultiprocExecutor.
            log_stats: Whether to log stats. When False but custom stat
                loggers (explicit or plugin-provided) exist, stat logging is
                still enabled, only without the default stat loggers.
            usage_context: Usage context of the LLM (not read by this
                constructor).
            mm_registry: Multi-modal registry (not read by this constructor).
            use_cached_outputs: Whether to use cached outputs (not read by
                this constructor).
            log_requests: Whether to log per-request events.
            start_engine_loop: Whether to start the engine loop (not read by
                this constructor; the output handler is started eagerly when
                an event loop is already running, otherwise by the first
                add_request() call).
            stat_loggers: Customized stat loggers for the engine, used in
                addition to any plugin-provided stat logger factories.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            aggregate_engine_logging: Forwarded to StatLoggerManager.
            client_addresses: Forwarded to the EngineCore client.
            client_count: Forwarded to the EngineCore client and to
                StatLoggerManager.
            client_index: Forwarded to the EngineCore client.

        Returns:
            None
        """
        # Make sure custom transformer configs can be serialized.
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.log_requests = log_requests

        # Explicitly passed stat loggers followed by plugin-provided ones.
        extra_loggers = [*(stat_loggers or []), *load_stat_logger_plugin_factories()]

        # Custom loggers force stat collection on, even if log_stats is False.
        self.log_stats = log_stats or bool(extra_loggers)
        if extra_loggers and not log_stats:
            logger.info(
                "AsyncLLM created with log_stats=False, "
                "but custom stat loggers were found; "
                "enabling logging without default stat loggers."
            )

        self.input_processor = InputProcessor(self.vllm_config)
        self.io_processor = get_io_processor(
            self.vllm_config, self.model_config.io_processor_plugin
        )

        # Converts EngineCoreOutputs into RequestOutputs in this process.
        self.output_processor = OutputProcessor(
            self.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
        )

        traces_endpoint = self.observability_config.otlp_traces_endpoint
        if traces_endpoint is not None:
            self.output_processor.tracer = init_tracer(
                "vllm.llm_engine", traces_endpoint
            )

        # Client for the EngineCore, which runs in a background process.
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Stat loggers; default loggers only when log_stats was requested.
        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=extra_loggers,
                enable_default_loggers=log_stats,
                client_count=client_count,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager = manager
            manager.log_engine_initialized()

        # Pause / resume state for async RL workflows.
        self._pause_cond = asyncio.Condition()
        self._paused = False

        self.output_handler: asyncio.Task | None = None
        try:
            # Only succeeds inside a running event loop; in that case the
            # output handler is started right away.
            asyncio.get_running_loop()
            self._run_output_handler()
        except RuntimeError:
            # No running loop: add_request() starts the handler later.
            pass

        profiler_config = vllm_config.profiler_config
        if profiler_config.profiler != "torch" or profiler_config.ignore_frontend:
            self.profiler = None
        else:
            trace_dir = profiler_config.torch_profiler_dir
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                trace_dir,
            )
            trace_handler = torch.profiler.tensorboard_trace_handler(
                trace_dir,
                worker_name=f"{socket.gethostname()}_{os.getpid()}.async_llm",
                use_gzip=profiler_config.torch_profiler_use_gzip,
            )
            self.profiler = torch.profiler.profile(
                activities=[torch.profiler.ProfilerActivity.CPU],
                with_stack=profiler_config.torch_profiler_with_stack,
                on_trace_ready=trace_handler,
            )

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_log_requests: bool = False,
        aggregate_engine_logging: bool = False,
        disable_log_stats: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from an already-built VllmConfig.

        The executor class is chosen with ``Executor.get_class(vllm_config)``.
        ``enable_log_requests`` maps to the constructor's ``log_requests`` and
        ``disable_log_stats`` is inverted into ``log_stats``.
        """
        executor_class = Executor.get_class(vllm_config)
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
            log_requests=enable_log_requests,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            aggregate_engine_logging=aggregate_engine_logging,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from AsyncEngineArgs.

        The VllmConfig is built for ``usage_context``. The request-logging and
        stat-logging switches are taken from ``engine_args``.
        """
        config = engine_args.create_engine_config(usage_context)
        return cls(
            vllm_config=config,
            executor_class=Executor.get_class(config),
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    def __del__(self):
        self.shutdown()

243
244
245
    def shutdown(self):
        """Shut down the background engine process, IPC and output handler.

        Every component is fetched with ``getattr`` and skipped when absent,
        so this also works on a partially initialised instance (e.g. when
        called from ``__del__`` after ``__init__`` failed).
        """
        shutdown_prometheus()

        engine_core = getattr(self, "engine_core", None)
        if engine_core:
            engine_core.shutdown()

        input_processor = getattr(self, "input_processor", None)
        if input_processor:
            input_processor.close()

        output_handler = getattr(self, "output_handler", None)
        if output_handler is not None:
            cancel_task_threadsafe(output_handler)

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        return await self.engine_core.get_supported_tasks_async()

261
262
263
    async def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest | PromptType,
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        prompt_text: str | None = None,
    ) -> RequestOutputCollector:
        """Add a new request to the AsyncLLM and return its output collector.

        ``prompt`` is either raw input, which is turned into an
        EngineCoreRequest by the InputProcessor, or a prebuilt
        EngineCoreRequest. In the latter case its own ``request_id`` is used
        and the ``request_id`` argument is ignored (with a one-time warning).
        ``prompt_text`` may only be supplied together with an
        EngineCoreRequest.

        Waits while generation is paused. Sampling requests with ``n > 1`` are
        fanned out into ``n`` child requests that share the returned collector.

        Raises:
            EngineDeadError: if the engine has errored.
            ValueError: for prompt logprobs combined with
                ``kv_sharing_fast_prefill``, or ``prompt_text`` given with a
                non-EngineCoreRequest prompt.
        """
        if self.errored:
            raise EngineDeadError()

        is_pooling = isinstance(params, PoolingParams)
        fast_prefill = self.vllm_config.cache_config.kv_sharing_fast_prefill
        if fast_prefill and not is_pooling and params.prompt_logprobs:
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs"
            )

        tokenization_kwargs = (
            {} if tokenization_kwargs is None else tokenization_kwargs
        )
        _validate_truncation_size(
            self.model_config.max_model_len,
            params.truncate_prompt_tokens,
            tokenization_kwargs,
        )

        # Turn the input into an EngineCoreRequest (unless it already is one).
        if isinstance(prompt, EngineCoreRequest):
            request = prompt
            if request_id != request.request_id:
                logger.warning_once(
                    "AsyncLLM.add_request() was passed a request_id parameter that "
                    "does not match the EngineCoreRequest.request_id attribute. The "
                    "latter will be used, and the former will be ignored."
                )
        else:
            if prompt_text is not None:
                raise ValueError(
                    "should only provide prompt_text with EngineCoreRequest"
                )
            request = self.input_processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
                data_parallel_rank,
            )
            # Recover the prompt text from plain-string or mapping prompts.
            if isinstance(prompt, str):
                prompt_text = prompt
            elif isinstance(prompt, Mapping):
                prompt_text = cast(str | None, prompt.get("prompt"))

        self.input_processor.assign_request_id(request)

        # Starting the output handler here (not only in __init__) lets the
        # AsyncLLM be constructed before an event loop exists, so startup
        # failures can be handled gracefully in the OpenAI server.
        self._run_output_handler()

        # Block until generation is not paused.
        async with self._pause_cond:
            await self._pause_cond.wait_for(lambda: not self._paused)

        collector = RequestOutputCollector(params.output_kind, request.request_id)

        # Switch to the params stored on the request; process_inputs() may
        # have updated a cloned copy.
        params = request.params

        if is_pooling or params.n == 1:
            await self._add_request(request, prompt_text, None, 0, collector)
            return collector

        # Sampling with n > 1: fan out into child requests sharing one collector.
        assert isinstance(params, SamplingParams)
        parent_request = ParentRequest(request)
        for idx in range(params.n):
            child_id, child_params = parent_request.get_child_info(idx)
            # The original request object is reused for the last child only.
            is_last = idx == params.n - 1
            child_request = request if is_last else copy(request)
            child_request.request_id = child_id
            child_request.sampling_params = child_params
            await self._add_request(
                child_request, prompt_text, parent_request, idx, collector
            )
        return collector

    async def _add_request(
        self,
        request: EngineCoreRequest,
369
370
        prompt: str | None,
        parent_req: ParentRequest | None,
371
372
373
        index: int,
        queue: RequestOutputCollector,
    ):
374
        # Add the request to OutputProcessor (this process).
375
        self.output_processor.add_request(request, prompt, parent_req, index, queue)
376

377
378
        # Add the EngineCoreRequest to EngineCore (separate process).
        await self.engine_core.add_request_async(request)
379

380
381
        if self.log_requests:
            logger.info("Added request %s.", request.request_id)
382
383
384
385
386
387

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
388
    async def generate(
        self,
        prompt: EngineCoreRequest | PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Making an AsyncStream corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the OutputProcessor.
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request AsyncStream.

        The caller of generate() iterates the returned AsyncGenerator,
        returning the RequestOutput back to the caller.

        Raises:
            asyncio.CancelledError / GeneratorExit: re-raised after aborting
                the request (client disconnect or generator closed).
            EngineDeadError: re-raised without aborting (engine shut down).
            ValueError: re-raised for invalid requests.
            EngineGenerateError: wraps any other exception, raised after
                aborting the request.
        """

        q: RequestOutputCollector | None = None
        try:
            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                prompt_text=prompt_text,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                assert isinstance(out, RequestOutput)
                yield out

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError as e:
            if self.log_requests:
                logger.info("Request %s failed (bad request): %s.", request_id, e)
            raise

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                # str(e) can itself raise; fall back to the class names only.
                try:
                    s = f"{e.__class__.__name__}: {e}"
                except Exception as e2:
                    s = (
                        f"{e.__class__.__name__}: "
                        "error during printing an exception of class "
                        f"{e2.__class__.__name__}"
                    )
                logger.info("Request %s failed due to %s.", request_id, s)
            raise EngineGenerateError() from e

    def _run_output_handler(self):
        """Start the background task that routes EngineCore outputs.

        Does nothing if the task already exists. Otherwise it creates a task
        that loops forever and, on each iteration:
          1) pulls EngineCoreOutputs from the EngineCore,
          2) hands them to the OutputProcessor in slices of at most
             ``VLLM_V1_OUTPUT_PROC_CHUNK_SIZE``, which pushes RequestOutputs
             into the per-request queues,
          3) aborts requests the processor reports (stop strings),
          4) records scheduler/iteration stats.
        Any exception is logged and passed to
        ``OutputProcessor.propagate_error``.
        """
        if self.output_handler is not None:
            return

        # Bind everything the task needs to locals so the coroutine holds no
        # reference back to the AsyncLLM; otherwise the AsyncLLM would not be
        # garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        logger_manager = self.logger_manager
        input_processor = self.input_processor
        chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE

        async def output_handler():
            try:
                while True:
                    # 1) Pull the next batch of EngineCoreOutputs.
                    outputs = await engine_core.get_output_async()
                    core_outputs = outputs.outputs
                    total = len(core_outputs)
                    stats = IterationStats() if (log_stats and total) else None

                    # Work in slices of at most chunk_size so a large batch
                    # does not block the event loop for too long.
                    for offset in range(0, total, chunk_size):
                        stop = offset + chunk_size

                        # 2) Process this slice of EngineCoreOutputs.
                        processed = output_processor.process_outputs(
                            core_outputs[offset:stop], outputs.timestamp, stats
                        )
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed.request_outputs

                        # Let other asyncio tasks run between slices.
                        if stop < total:
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        to_abort = processed.reqs_to_abort
                        if to_abort:
                            await engine_core.abort_requests_async(to_abort)

                    output_processor.update_scheduler_stats(outputs.scheduler_stats)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if logger_manager:
                        logger_manager.record(
                            engine_idx=outputs.engine_index,
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=stats,
                            mm_cache_stats=input_processor.stat_mm_cache(),
                        )
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())

    async def abort(
        self, request_id: str | Iterable[str], internal: bool = False
    ) -> None:
553
        """Abort RequestId in OutputProcessor and EngineCore."""
554

555
556
557
        request_ids = (
            (request_id,) if isinstance(request_id, str) else as_list(request_id)
        )
558
        all_request_ids = self.output_processor.abort_requests(request_ids, internal)
559
        await self.engine_core.abort_requests_async(all_request_ids)
560

561
        if self.log_requests:
562
            logger.info("Aborted request(s) %s.", ",".join(request_ids))
563

564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
    async def pause_generation(
        self,
        *,
        wait_for_inflight_requests: bool = False,
        clear_cache: bool = True,
    ) -> None:
        """
        Pause generation to allow model weight updates.

        New generation/encoding requests are blocked until resume.

        Args:
            wait_for_inflight_requests: When ``True`` waits for in-flight
                requests to finish before pausing. When ``False`` (default),
                immediately aborts any in-flight requests.
            clear_cache: Whether to clear KV cache and prefix cache after
                draining. Set to ``False`` to preserve cache for faster resume.
                Default is ``True`` (clear caches).
        """

        async with self._pause_cond:
            if self._paused:
                return
            self._paused = True

        if not wait_for_inflight_requests:
            request_ids = list(self.output_processor.request_states.keys())
            if request_ids:
592
                await self.abort(request_ids, internal=True)
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615

        # Wait for running requests to drain before clearing cache.
        if self.output_processor.has_unfinished_requests():
            await self.output_processor.wait_for_requests_to_drain()

        # Clear cache
        if clear_cache:
            await self.reset_prefix_cache()
            await self.reset_mm_cache()

    async def resume_generation(self) -> None:
        """Resume generation after :meth:`pause_generation`."""

        async with self._pause_cond:
            self._paused = False
            self._pause_cond.notify_all()  # Wake up all waiting requests

    async def is_paused(self) -> bool:
        """Return whether the engine is currently paused."""

        async with self._pause_cond:
            return self._paused

616
    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        truncate_prompt_tokens: int | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Main function called by the API server to kick off a pooling request
            * 1) Making an AsyncStream corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request AsyncStream.

        The caller of encode() iterates the returned AsyncGenerator,
        receiving PoolingRequestOutputs.

        NOTE: truncate_prompt_tokens is deprecated in v0.14. Until it is
        removed, a value passed here is still honoured: if
        ``pooling_params.truncate_prompt_tokens`` is unset, it is applied to
        a copy of ``pooling_params`` (the caller's object is not mutated).
        TODO: Remove truncate_prompt_tokens in v0.15.

        Raises:
            asyncio.CancelledError / GeneratorExit: re-raised after aborting
                the request (client disconnect or generator closed).
            EngineDeadError: re-raised without aborting (engine shut down).
            ValueError: re-raised for invalid requests.
            EngineGenerateError: wraps any other exception, raised after
                aborting the request.
        """

        q: RequestOutputCollector | None = None
        try:
            if truncate_prompt_tokens is not None:
                warnings.warn(
                    "The `truncate_prompt_tokens` parameter in `AsyncLLM.encode()` "
                    "is deprecated and will be removed in v0.15. "
                    "Please use `pooling_params.truncate_prompt_tokens` instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )
                # Keep the deprecated argument effective (add_request reads
                # params.truncate_prompt_tokens); an explicit value on the
                # params takes precedence.
                if pooling_params.truncate_prompt_tokens is None:
                    pooling_params = copy(pooling_params)
                    pooling_params.truncate_prompt_tokens = truncate_prompt_tokens

            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, encode() is
        # cancelled, or the generator is closed / garbage collected
        # (GeneratorExit). Either way, abort the request so it does not
        # keep running in the engine.
        except (asyncio.CancelledError, GeneratorExit):
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError as e:
            if self.log_requests:
                logger.info("Request %s failed (bad request): %s.", request_id, e)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    @property
708
    def tokenizer(self) -> TokenizerLike | None:
709
        return self.input_processor.tokenizer
710

711
712
    def get_tokenizer(self) -> TokenizerLike:
        return self.input_processor.get_tokenizer()
713

714
715
716
    @property
    def renderer(self) -> RendererLike:
        return self.input_processor.renderer
717
718

    async def is_tracing_enabled(self) -> bool:
719
        return self.observability_config.otlp_traces_endpoint is not None  # type: ignore
720

721
    async def do_log_stats(self) -> None:
722
723
        if self.logger_manager:
            self.logger_manager.log()
724
725
726

    async def check_health(self) -> None:
        """Raise ``self.dead_error`` if the engine has errored; else no-op."""
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error

    async def start_profile(self) -> None:
731
732
733
734
        coros = [self.engine_core.profile_async(True)]
        if self.profiler is not None:
            coros.append(asyncio.to_thread(self.profiler.start))
        await asyncio.gather(*coros)
735
736

    async def stop_profile(self) -> None:
737
738
739
740
        coros = [self.engine_core.profile_async(False)]
        if self.profiler is not None:
            coros.append(asyncio.to_thread(self.profiler.stop))
        await asyncio.gather(*coros)
741

742
    async def reset_mm_cache(self) -> None:
        """Clear the multimodal cache locally, then in the engine core."""
        processor = self.input_processor
        processor.clear_mm_cache()
        core = self.engine_core
        await core.reset_mm_cache_async()

    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Forward a prefix-cache reset request to the engine core.

        Args:
            reset_running_requests: Passed positionally to the engine core.
            reset_connector: Passed positionally to the engine core.

        Returns:
            Whatever ``engine_core.reset_prefix_cache_async`` returns.
        """
        core_reset = self.engine_core.reset_prefix_cache_async
        outcome = await core_reset(reset_running_requests, reset_connector)
        return outcome

    async def sleep(self, level: int = 1) -> None:
        """Reset the prefix cache, then put the engine core to sleep.

        ``level`` is forwarded to ``engine_core.sleep_async`` and recorded,
        together with a sleeping flag of ``1``, on the stat-logger manager
        when one is present.
        """
        await self.reset_prefix_cache()
        await self.engine_core.sleep_async(level)

        manager = self.logger_manager
        if manager is None:
            return
        manager.record_sleep_state(1, level)

    async def wake_up(self, tags: list[str] | None = None) -> None:
761
        await self.engine_core.wake_up_async(tags)
762

763
764
765
        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

766
767
768
    async def is_sleeping(self) -> bool:
        """Query the engine core for its current sleep state."""
        core = self.engine_core
        sleeping = await core.is_sleeping_async()
        return sleeping

    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests.

        Delegates to ``engine_core.add_lora_async`` and returns its result.
        """
        core = self.engine_core
        return await core.add_lora_async(lora_request)

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter.

        Delegates to ``engine_core.remove_lora_async`` and returns its result.
        """
        core = self.engine_core
        return await core.remove_lora_async(lora_id)

    async def list_loras(self) -> set[int]:
        """List all registered adapters.

        Delegates to ``engine_core.list_loras_async`` and returns its result.
        """
        core = self.engine_core
        return await core.list_loras_async()

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted.

        Delegates to ``engine_core.pin_lora_async`` and returns its result.
        """
        core = self.engine_core
        return await core.pin_lora_async(lora_id)

    async def collective_rpc(
        self,
        method: str,
788
        timeout: float | None = None,
789
        args: tuple = (),
790
        kwargs: dict | None = None,
791
    ):
792
793
794
795
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
796
797
            method, timeout, args, kwargs
        )
798

799
800
801
802
803
804
805
806
    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """Wait for all requests to be drained.

        Polls ``engine_core.dp_engines_running()`` about once per second and
        returns as soon as it reports that no engine is running.

        Args:
            drain_timeout: Maximum number of seconds to wait.

        Raises:
            TimeoutError: If the engines are still running after
                ``drain_timeout`` seconds.
        """
        # A monotonic clock is immune to wall-clock adjustments (NTP, manual
        # changes), which could otherwise shorten or extend the wait.
        deadline = time.monotonic() + drain_timeout
        while time.monotonic() < deadline:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

            logger.info("Engines are still running, waiting for requests to drain...")
            await asyncio.sleep(1)  # Wait 1 second before checking again

        raise TimeoutError(
            f"Timeout reached after {drain_timeout} seconds "
            "waiting for requests to drain."
        )

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ):
        """
        Scale up or down the data parallel size by adding or removing
        engine cores.

        Waits for in-flight requests to drain, asks the engine core to
        rescale, then stores the new size on ``vllm_config.parallel_config``.
        When growing with stats logging enabled, the stat logger manager is
        rebuilt for the new engine indices. Requesting the current size is a
        logged no-op.

        Args:
            new_data_parallel_size: The new number of data parallel workers.
            drain_timeout:
                Maximum time to wait for requests to drain (seconds).

        Raises:
            TimeoutError: If requests do not drain within ``drain_timeout``
                (raised by ``wait_for_requests_to_drain``).
        """
        old_data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if old_data_parallel_size == new_data_parallel_size:
            logger.info(
                "Data parallel size is already %s, skipping scale",
                new_data_parallel_size,
            )
            return

        # This path handles both growing and shrinking, so report the actual
        # transition rather than assuming a scale-up.
        logger.info(
            "Waiting for requests to drain before scaling from %s to %s engines...",
            old_data_parallel_size,
            new_data_parallel_size,
        )
        await self.wait_for_requests_to_drain(drain_timeout)
        logger.info(
            "Requests have been drained, proceeding with scale to %s engines",
            new_data_parallel_size,
        )
        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size

        # recreate stat loggers
        # NOTE(review): only done when growing; on shrink the existing manager
        # keeps its previous engine_idxs — confirm that is intended.
        if new_data_parallel_size > old_data_parallel_size and self.log_stats:
            # TODO(rob): fix this after talking with Ray team.
            # This resets all the prometheus metrics since we
            # unregister during initialization. Need to understand
            # the intended behavior here better.
            self.logger_manager = StatLoggerManager(
                vllm_config=self.vllm_config,
                engine_idxs=list(range(new_data_parallel_size)),
                custom_stat_loggers=None,
            )

    @property
    def is_running(self) -> bool:
        """Whether ``output_handler`` has not finished.

        Also true before the output loop is started, while the handler is
        still ``None``.
        """
        handler = self.output_handler
        if handler is None:
            return True
        return not handler.done()

    @property
    def is_stopped(self) -> bool:
        """Same value as ``errored``."""
        stopped = self.errored
        return stopped

    @property
    def errored(self) -> bool:
        """True when the engine core reports the engine dead or ``is_running`` is false."""
        dead = self.engine_core.resources.engine_dead
        return dead or not self.is_running

    @property
    def dead_error(self) -> BaseException:
        """Build a new ``EngineDeadError`` on every access."""
        error: BaseException = EngineDeadError()
        return error