# async_llm.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import os
import socket
6
import time
7
from collections.abc import AsyncGenerator, Iterable, Mapping
8
from copy import copy
9
from typing import Any, cast
10

11
import numpy as np
12
import torch
13
from typing_extensions import deprecated
14

15
import vllm.envs as envs
16
from vllm.config import VllmConfig
17
from vllm.engine.arg_utils import AsyncEngineArgs
18
from vllm.engine.protocol import EngineClient
19
from vllm.entrypoints.utils import _validate_truncation_size
20
from vllm.inputs import PromptType
21
22
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
23
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
24
from vllm.outputs import PoolingRequestOutput, RequestOutput
25
from vllm.plugins.io_processors import get_io_processor
26
from vllm.pooling_params import PoolingParams
27
from vllm.sampling_params import SamplingParams
28
from vllm.tasks import SupportedTask
29
from vllm.tokenizers import TokenizerLike
30
from vllm.tracing import init_tracer
31
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
32
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
33
from vllm.usage.usage_lib import UsageContext
34
35
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
36
from vllm.utils.math_utils import cdiv
37
from vllm.v1.engine import EngineCoreRequest
38
from vllm.v1.engine.core_client import EngineCoreClient
39
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
40
from vllm.v1.engine.input_processor import InputProcessor
41
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
42
from vllm.v1.engine.parallel_sampling import ParentRequest
43
from vllm.v1.executor import Executor
44
45
46
47
48
from vllm.v1.metrics.loggers import (
    StatLoggerFactory,
    StatLoggerManager,
    load_stat_logger_plugin_factories,
)
49
from vllm.v1.metrics.prometheus import shutdown_prometheus
50
from vllm.v1.metrics.stats import IterationStats
51
52
53
54
55
56
57
58

logger = init_logger(__name__)


class AsyncLLM(EngineClient):
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: list[StatLoggerFactory] | None = None,
        aggregate_engine_logging: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats. Stats collection is also
                enabled (without the default loggers) when custom or
                plugin-provided stat loggers are present.
            usage_context: Usage context of the LLM. Not read by this
                constructor.
            mm_registry: Multi-modal registry. Not read by this constructor.
            use_cached_outputs: Whether to use cached outputs. Not read by
                this constructor.
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop. Not read by
                this constructor: the output handler is started here only
                when an event loop is already running, otherwise on the
                first generate()/encode() call.
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            aggregate_engine_logging: Forwarded to StatLoggerManager.
            client_addresses: Forwarded to the engine-core client.
            client_count: Forwarded to the engine-core client and to
                StatLoggerManager.
            client_index: Forwarded to the engine-core client.

        Returns:
            None
        """
        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.log_requests = log_requests

        # User-supplied stat loggers plus any discovered through plugins.
        custom_stat_loggers = list(stat_loggers or [])
        custom_stat_loggers.extend(load_stat_logger_plugin_factories())

        # Custom loggers force stats collection on even when log_stats=False;
        # the default loggers remain governed by the log_stats argument.
        has_custom_loggers = bool(custom_stat_loggers)
        self.log_stats = log_stats or has_custom_loggers
        if not log_stats and has_custom_loggers:
            logger.info(
                "AsyncLLM created with log_stats=False, "
                "but custom stat loggers were found; "
                "enabling logging without default stat loggers."
            )

        if self.model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = init_tokenizer_from_configs(self.model_config)

        self.input_processor = InputProcessor(self.vllm_config, tokenizer)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        # NOTE: self.tokenizer is the property backed by the input processor.
        self.output_processor = OutputProcessor(
            self.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
        )
        endpoint = self.observability_config.otlp_traces_endpoint
        if endpoint is not None:
            tracer = init_tracer("vllm.llm_engine", endpoint)
            self.output_processor.tracer = tracer

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Loggers.
        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=custom_stat_loggers,
                enable_default_loggers=log_stats,
                client_count=client_count,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        # Pause / resume state for async RL workflows.
        self._pause_cond = asyncio.Condition()
        self._paused = False

        self.output_handler: asyncio.Task | None = None
        # Start the output handler eagerly if we are in a running event loop;
        # otherwise generate()/encode() start it on first use. Only the loop
        # probe is guarded, so a RuntimeError raised while actually starting
        # the handler propagates instead of being silently swallowed.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass
        else:
            self._run_output_handler()

        if (
            envs.VLLM_TORCH_PROFILER_DIR
            and not envs.VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM
        ):
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                envs.VLLM_TORCH_PROFILER_DIR,
            )
            if envs.VLLM_PROFILER_MAX_ITERS > 0 or envs.VLLM_PROFILER_DELAY_ITERS > 0:
                logger.warning_once(
                    "Torch profiler received max_iters or delay_iters setting. These "
                    "are not compatible with the AsyncLLM profiler and will be ignored "
                    "for the AsyncLLM process. Engine process profiling will still "
                    "respect these settings. Consider setting "
                    "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM=1 to disable "
                    "AsyncLLM profiling."
                )
            worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    envs.VLLM_TORCH_PROFILER_DIR, worker_name=worker_name, use_gzip=True
                ),
            )
        else:
            self.profiler = None

    @property
    @deprecated(
        "`AsyncLLM.processor` has been renamed to `AsyncLLM.input_processor`. "
        "The old name will be removed in v0.13."
    )
    def processor(self):
        """Deprecated alias that returns ``self.input_processor``."""
        current = self.input_processor
        return current

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_log_requests: bool = False,
        aggregate_engine_logging: bool = False,
        disable_log_stats: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> "AsyncLLM":
        """Build an AsyncLLM from an already-constructed VllmConfig.

        The executor class is selected from the config with
        ``Executor.get_class``. ``enable_log_requests`` becomes
        ``log_requests`` and ``disable_log_stats`` is inverted into
        ``log_stats``; everything else is forwarded unchanged.
        """
        executor_class = Executor.get_class(vllm_config)
        should_log_stats = not disable_log_stats
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=should_log_stats,
            log_requests=enable_log_requests,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            aggregate_engine_logging=aggregate_engine_logging,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
    ) -> "AsyncLLM":
        """Build an AsyncLLM from AsyncEngineArgs.

        The engine config is derived from ``engine_args`` for the given
        ``usage_context``; request/stat logging flags are taken from the
        args as well.
        """
        config = engine_args.create_engine_config(usage_context)
        executor = Executor.get_class(config)

        return cls(
            vllm_config=config,
            executor_class=executor,
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    def __del__(self):
        """Release engine resources via shutdown() on garbage collection."""
        self.shutdown()

    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC.

        Attributes are looked up with ``getattr`` defaults, so this is safe to
        call on an instance whose ``__init__`` did not complete.
        """
        shutdown_prometheus()

        core_client = getattr(self, "engine_core", None)
        if core_client:
            core_client.shutdown()

        handler_task = getattr(self, "output_handler", None)
        if handler_task is not None:
            cancel_task_threadsafe(handler_task)

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks that EngineCore reports as supported."""
        supported = await self.engine_core.get_supported_tasks_async()
        return supported

    async def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest | PromptType,
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        prompt_text: str | None = None,
    ) -> RequestOutputCollector:
        """Register a request with the AsyncLLM and return its output queue.

        ``prompt`` is either a ready-made EngineCoreRequest (used as-is) or a
        raw prompt that the input processor converts into one. When the
        resulting sampling params ask for ``n > 1`` samples, the request is
        fanned out into ``n`` child requests that share one ParentRequest
        and one output collector.

        Raises:
            EngineDeadError: if the engine has already errored.
        """
        if self.errored:
            raise EngineDeadError()

        pooling = isinstance(params, PoolingParams)

        # A single collector receives outputs for the request and, when it is
        # fanned out, for every child.
        collector = RequestOutputCollector(output_kind=params.output_kind)

        # Convert Input --> Request.
        if not isinstance(prompt, EngineCoreRequest):
            assert prompt_text is None
            request = self.input_processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
                data_parallel_rank,
            )
            # Recover the prompt text for the output processor when the raw
            # prompt carries one.
            if isinstance(prompt, str):
                prompt_text = prompt
            elif isinstance(prompt, Mapping):
                prompt_text = cast(str | None, prompt.get("prompt"))
        else:
            request = prompt

        # process_inputs() may have cloned/updated the params; use those.
        params = request.params

        if pooling or params.n == 1:
            await self._add_request(request, prompt_text, None, 0, collector)
            return collector

        assert isinstance(params, SamplingParams)

        # Fan out child requests (for n>1). Every child except the last gets
        # a shallow copy; the last one reuses the original request object.
        parent = ParentRequest(request_id, params)
        for child_index in range(params.n):
            child_id, child_params = parent.get_child_info(child_index)
            if child_index == params.n - 1:
                child = request
            else:
                child = copy(request)
            child.request_id = child_id
            child.sampling_params = child_params
            await self._add_request(
                child, prompt_text, parent, child_index, collector
            )
        return collector

    async def _add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        index: int,
        queue: RequestOutputCollector,
    ):
        """Register ``request`` in this process, then send it to EngineCore.

        The OutputProcessor (this process) is told about the request before
        the request is handed to EngineCore (a separate process).
        """
        output_processor = self.output_processor
        output_processor.add_request(request, prompt, parent_req, index, queue)

        await self.engine_core.add_request_async(request)

        if not self.log_requests:
            return
        logger.info("Added request %s.", request.request_id)

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: EngineCoreRequest | PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Making an AsyncStream corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the Detokenizer.
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request AsyncStream.

        The caller of generate() iterates the returned AsyncGenerator,
        returning the RequestOutput back to the caller.

        Raises:
            ValueError: for rejected requests, e.g. prompt logprobs requested
                while ``kv_sharing_fast_prefill`` is enabled.
            EngineDeadError: if the engine has died.
            EngineGenerateError: wrapping any other unexpected error.
        """

        if (
            self.vllm_config.cache_config.kv_sharing_fast_prefill
            and sampling_params.prompt_logprobs
        ):
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs"
            )

        try:
            # We start the output_handler on the first call to generate() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()

            # Wait until generation is resumed if the engine is paused.
            async with self._pause_cond:
                await self._pause_cond.wait_for(lambda: not self._paused)

            if tokenization_kwargs is None:
                tokenization_kwargs = {}

            # Validate truncation for every request, including those whose
            # caller supplied tokenization_kwargs (matching encode()).
            # Previously this ran only when tokenization_kwargs was None, so
            # truncate_prompt_tokens was silently ignored in that case.
            truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
            _validate_truncation_size(
                self.model_config.max_model_len,
                truncate_prompt_tokens,
                tokenization_kwargs,
            )

            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                prompt_text=prompt_text,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                assert isinstance(out, RequestOutput)
                finished = out.finished
                yield out

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    def _run_output_handler(self):
        """Start the background output-handler task if it is not running.

        The task loops forever:
          1) pulls EngineCoreOutputs from EngineCore,
          2) hands them to the OutputProcessor in bounded chunks (which
             pushes RequestOutputs into the per-request queues),
          3) aborts in EngineCore any requests the OutputProcessor flags
             (stop strings),
          4) records scheduler/iteration stats.
        On any exception it logs and forwards the error through
        ``output_processor.propagate_error``. Calling this again while
        ``self.output_handler`` is set is a no-op. Uses
        ``asyncio.create_task``, so it must run inside an event loop.
        """

        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        logger_manager = self.logger_manager
        input_processor = self.input_processor
        # Maximum outputs processed before yielding to the event loop; read
        # once, and clamped so a zero setting cannot produce an invalid step.
        chunk_size = max(1, envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE)

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    # Assumed to be a sliceable sequence (a list) — it is
                    # len()-ed here and sliced below.
                    engine_core_outputs = outputs.outputs
                    num_outputs = len(engine_core_outputs)

                    iteration_stats = (
                        IterationStats() if (log_stats and num_outputs) else None
                    )

                    # Process in chunks of at most chunk_size so a large batch
                    # does not block the event loop for too long. Plain list
                    # slicing avoids np.array_split turning the output
                    # objects into numpy object arrays.
                    for start in range(0, num_outputs, chunk_size):
                        end = start + chunk_size
                        # 2) Process EngineCoreOutputs.
                        processed_outputs = output_processor.process_outputs(
                            engine_core_outputs[start:end],
                            outputs.timestamp,
                            iteration_stats,
                        )
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed_outputs.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if end < num_outputs:
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        if processed_outputs.reqs_to_abort:
                            await engine_core.abort_requests_async(
                                processed_outputs.reqs_to_abort
                            )

                    output_processor.update_scheduler_stats(outputs.scheduler_stats)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if logger_manager:
                        logger_manager.record(
                            engine_idx=outputs.engine_index,
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                            mm_cache_stats=input_processor.stat_mm_cache(),
                        )
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())

    async def abort(self, request_id: str | Iterable[str]) -> None:
        """Abort one or more requests in the OutputProcessor and EngineCore.

        Args:
            request_id: a single request id, or an iterable of ids.
        """
        if isinstance(request_id, str):
            request_ids = (request_id,)
        else:
            request_ids = as_list(request_id)

        # The OutputProcessor returns the ids to abort in EngineCore;
        # presumably this expands parent ids into their child ids — confirm
        # in OutputProcessor.abort_requests.
        engine_request_ids = self.output_processor.abort_requests(request_ids)
        await self.engine_core.abort_requests_async(engine_request_ids)

        if not self.log_requests:
            return
        logger.info("Aborted request(s) %s.", ",".join(request_ids))

    async def pause_generation(
        self,
        *,
        wait_for_inflight_requests: bool = False,
        clear_cache: bool = True,
    ) -> None:
        """
        Pause generation to allow model weight updates.

        While paused, new generate()/encode() calls block until
        :meth:`resume_generation`. Calling this while already paused returns
        immediately.

        Args:
            wait_for_inflight_requests: If ``True``, in-flight requests are
                allowed to finish. If ``False`` (default), they are aborted
                right away.
            clear_cache: If ``True`` (default), reset the prefix cache and
                the multi-modal cache once requests have drained. Pass
                ``False`` to keep them for a faster resume.
        """
        async with self._pause_cond:
            was_paused = self._paused
            self._paused = True
        if was_paused:
            return

        if not wait_for_inflight_requests:
            inflight_ids = list(self.output_processor.request_states.keys())
            if inflight_ids:
                await self.abort(inflight_ids)

        # Let running requests drain before any cache is touched.
        if self.output_processor.has_unfinished_requests():
            await self.output_processor.wait_for_requests_to_drain()

        if not clear_cache:
            return
        await self.reset_prefix_cache()
        await self.reset_mm_cache()

    async def resume_generation(self) -> None:
        """Lift a pause set by :meth:`pause_generation`."""

        condition = self._pause_cond
        async with condition:
            self._paused = False
            # Wake every generate()/encode() call blocked on the pause.
            condition.notify_all()

    async def is_paused(self) -> bool:
        """Report whether generation is currently paused."""

        async with self._pause_cond:
            paused = self._paused
        return paused

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        truncate_prompt_tokens: int | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Main function called by the API server to kick off a pooling request
            * 1) Making an AsyncStream corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request AsyncStream.

        The caller of encode() iterates the returned AsyncGenerator,
        receiving PoolingRequestOutput objects until one is finished.

        Raises:
            EngineDeadError: if the engine has died.
            ValueError: for requests rejected during validation.
            EngineGenerateError: wrapping any other unexpected error.
        """

        try:
            # We start the output_handler on the first call to encode() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()

            # Respect pause state before accepting new requests.
            async with self._pause_cond:
                await self._pause_cond.wait_for(lambda: not self._paused)

            if tokenization_kwargs is None:
                tokenization_kwargs = {}
            _validate_truncation_size(
                self.model_config.max_model_len,
                truncate_prompt_tokens,
                tokenization_kwargs,
            )

            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, encode() is
        # cancelled or the generator is closed / garbage collected. Both
        # CancelledError and GeneratorExit are BaseExceptions that bypass
        # `except Exception`, so catch them explicitly (as generate() does)
        # and abort the request.
        except (asyncio.CancelledError, GeneratorExit):
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """Tokenizer held by the input processor (None when init skipped)."""
        processor = self.input_processor
        return processor.tokenizer

    @tokenizer.setter
    def tokenizer(self, tokenizer: TokenizerLike | None) -> None:
        """Install ``tokenizer`` on the input processor."""
        processor = self.input_processor
        processor.tokenizer = tokenizer

    async def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer.

        Raises:
            ValueError: if no tokenizer is available because tokenizer
                initialisation was skipped.
        """
        current = self.tokenizer
        if current is not None:
            return current
        raise ValueError(
            "Unable to get tokenizer because `skip_tokenizer_init=True`"
        )

    async def is_tracing_enabled(self) -> bool:
        """True when an OTLP traces endpoint is configured."""
        traces_endpoint = self.observability_config.otlp_traces_endpoint
        return traces_endpoint is not None  # type: ignore

    async def do_log_stats(self) -> None:
        """Ask the stat-logger manager, when one exists, to emit its logs."""
        manager = self.logger_manager
        if not manager:
            return
        manager.log()

    async def check_health(self) -> None:
        """Raise the engine's dead error if the engine has errored."""
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error

    async def start_profile(self) -> None:
        """Start profiling in EngineCore and, if configured, in this process.

        The local torch profiler (when present) is started via
        ``asyncio.to_thread`` concurrently with the EngineCore request.
        """
        engine_start = self.engine_core.profile_async(True)
        if self.profiler is None:
            await asyncio.gather(engine_start)
        else:
            await asyncio.gather(
                engine_start, asyncio.to_thread(self.profiler.start)
            )

    async def stop_profile(self) -> None:
        """Stop profiling in EngineCore and, if configured, in this process.

        The local torch profiler (when present) is stopped via
        ``asyncio.to_thread`` concurrently with the EngineCore request.
        """
        engine_stop = self.engine_core.profile_async(False)
        if self.profiler is None:
            await asyncio.gather(engine_stop)
        else:
            await asyncio.gather(
                engine_stop, asyncio.to_thread(self.profiler.stop)
            )

    async def reset_mm_cache(self) -> None:
        """Clear multimodal caches: the input processor's first, then the core's."""
        processor = self.input_processor
        processor.clear_mm_cache()
        core = self.engine_core
        await core.reset_mm_cache_async()

751
    async def reset_prefix_cache(self) -> None:
        """Ask the engine core to reset its prefix cache."""
        core = self.engine_core
        await core.reset_prefix_cache_async()

754
    async def sleep(self, level: int = 1) -> None:
        """Put the engine to sleep at ``level``.

        Steps, in order: reset the prefix cache, tell the engine core to
        sleep, then (if a logger manager exists) record the sleep state as
        ``record_sleep_state(1, level)``.
        """
        await self.reset_prefix_cache()
        core = self.engine_core
        await core.sleep_async(level)
        manager = self.logger_manager
        if manager is None:
            return
        manager.record_sleep_state(1, level)

761
    async def wake_up(self, tags: list[str] | None = None) -> None:
762
        await self.engine_core.wake_up_async(tags)
763

764
765
766
        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

767
768
769
    async def is_sleeping(self) -> bool:
        """Ask the engine core whether it is currently sleeping."""
        core = self.engine_core
        asleep = await core.is_sleeping_async()
        return asleep

770
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests.

        Delegates to ``engine_core.add_lora_async`` and returns its result.
        """
        core = self.engine_core
        result = await core.add_lora_async(lora_request)
        return result

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter.

        Delegates to ``engine_core.remove_lora_async`` and returns its result.
        """
        core = self.engine_core
        result = await core.remove_lora_async(lora_id)
        return result

778
    async def list_loras(self) -> set[int]:
        """List all registered adapters.

        Delegates to ``engine_core.list_loras_async`` and returns its result.
        """
        core = self.engine_core
        adapters = await core.list_loras_async()
        return adapters

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted.

        Delegates to ``engine_core.pin_lora_async`` and returns its result.
        """
        core = self.engine_core
        result = await core.pin_lora_async(lora_id)
        return result
785

786
787
788
    async def collective_rpc(
        self,
        method: str,
789
        timeout: float | None = None,
790
        args: tuple = (),
791
        kwargs: dict | None = None,
792
    ):
793
794
795
796
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
797
798
            method, timeout, args, kwargs
        )
799

800
801
802
803
804
805
806
807
    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """Wait until no data-parallel engine reports running, or time out.

        ``engine_core.dp_engines_running()`` is polled about once per second.
        The engines are always checked at least once, and once more after the
        final sleep. As a result, engines that are already idle (or that go
        idle during the last sleep) never cause a timeout, even when
        ``drain_timeout <= 0``.

        Args:
            drain_timeout: Maximum number of seconds to wait.

        Raises:
            TimeoutError: If engines are still running once ``drain_timeout``
                seconds have elapsed.
        """
        # Use a monotonic clock so wall-clock adjustments (NTP sync, manual
        # changes) cannot shorten or extend the wait.
        deadline = time.monotonic() + drain_timeout
        while True:
            # Check idleness before the deadline so a drain that completes
            # during the final sleep is not misreported as a timeout.
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise TimeoutError(
                    f"Timeout reached after {drain_timeout} seconds "
                    "waiting for requests to drain."
                )
            logger.info("Engines are still running, waiting for requests to drain...")
            # Poll roughly once per second, but never sleep past the deadline.
            await asyncio.sleep(min(1.0, remaining))
815

816
817
818
    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ):
        """
        Scale up or down the data parallel size by adding or removing
        engine cores.

        In-flight requests are drained first via
        ``self.wait_for_requests_to_drain``. Then the engine core is rescaled
        and ``parallel_config.data_parallel_size`` is updated. If the size is
        unchanged, this method does nothing.

        Args:
            new_data_parallel_size: The new number of data parallel workers
            drain_timeout:
                Maximum time to wait for requests to drain (seconds)

        Raises:
            Propagates any exception from ``wait_for_requests_to_drain`` or
            ``engine_core.scale_elastic_ep``.
        """
        old_data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if old_data_parallel_size == new_data_parallel_size:
            logger.info(
                "Data parallel size is already %s, skipping scale",
                new_data_parallel_size,
            )
            return
        # The size may go up or down, so report both endpoints rather than
        # claiming a scale-up.
        logger.info(
            "Waiting for requests to drain before scaling from %s to %s engines...",
            old_data_parallel_size,
            new_data_parallel_size,
        )
        await self.wait_for_requests_to_drain(drain_timeout)
        logger.info(
            "Requests have been drained, proceeding with scale to %s engines",
            new_data_parallel_size,
        )
        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size

        # Recreate stat loggers. This happens only on scale-up; on scale-down
        # the existing manager (built for the old engine count) is kept.
        if new_data_parallel_size > old_data_parallel_size and self.log_stats:
            # TODO(rob): fix this after talking with Ray team.
            # This resets all the prometheus metrics since we
            # unregister during initialization. Need to understand
            # the intended behavior here better.
            self.logger_manager = StatLoggerManager(
                vllm_config=self.vllm_config,
                engine_idxs=list(range(new_data_parallel_size)),
                custom_stat_loggers=None,
            )

858
859
    @property
    def is_running(self) -> bool:
        """True unless the output handler task exists and has finished.

        ``output_handler`` is None before the output loop is started, which
        also counts as running.
        """
        handler = self.output_handler
        if handler is None:
            return True
        return not handler.done()
862
863
864

    @property
    def is_stopped(self) -> bool:
        """Alias for ``errored``: the engine counts as stopped once it errors."""
        stopped = self.errored
        return stopped
866
867
868

    @property
    def errored(self) -> bool:
        """True if the engine core is marked dead or the engine is not running."""
        engine_dead = self.engine_core.resources.engine_dead
        if engine_dead:
            # Mirror ``a or b``: return the truthy value itself.
            return engine_dead
        return not self.is_running
870
871
872

    @property
    def dead_error(self) -> BaseException:
        """A fresh ``EngineDeadError`` instance (a new one on every access)."""
        error = EngineDeadError()
        return error