async_llm.py 33.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import os
import socket
6
import time
7
from collections.abc import AsyncGenerator, Iterable, Mapping
8
from copy import copy
9
from typing import Any, cast
10

11
import numpy as np
12
import torch
13
from typing_extensions import deprecated
14

15
import vllm.envs as envs
16
from vllm.config import VllmConfig
17
from vllm.engine.arg_utils import AsyncEngineArgs
18
from vllm.engine.protocol import EngineClient
19
from vllm.entrypoints.utils import _validate_truncation_size
20
from vllm.inputs import PromptType
21
22
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
23
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
24
from vllm.outputs import PoolingRequestOutput, RequestOutput
25
from vllm.plugins.io_processors import get_io_processor
26
from vllm.pooling_params import PoolingParams
27
from vllm.sampling_params import SamplingParams
28
from vllm.tasks import SupportedTask
29
from vllm.tokenizers import TokenizerLike, init_tokenizer_from_config
30
from vllm.tracing import init_tracer
31
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
32
from vllm.usage.usage_lib import UsageContext
33
34
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
35
from vllm.utils.math_utils import cdiv
36
from vllm.v1.engine import EngineCoreRequest
37
from vllm.v1.engine.core_client import EngineCoreClient
38
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
39
from vllm.v1.engine.input_processor import InputProcessor
40
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
41
from vllm.v1.engine.parallel_sampling import ParentRequest
42
from vllm.v1.executor import Executor
43
44
45
46
47
from vllm.v1.metrics.loggers import (
    StatLoggerFactory,
    StatLoggerManager,
    load_stat_logger_plugin_factories,
)
48
from vllm.v1.metrics.prometheus import shutdown_prometheus
49
from vllm.v1.metrics.stats import IterationStats
50
51
52
53
54
55
56
57

# Module-level logger used by AsyncLLM methods and its background output handler.
logger = init_logger(__name__)


class AsyncLLM(EngineClient):
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: list[StatLoggerFactory] | None = None,
        aggregate_engine_logging: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats. Stats are also collected when
                custom stat loggers (passed in or loaded from plugins) are
                present; in that case the default stat loggers stay disabled.
            usage_context: Usage context of the LLM. Not referenced by this
                constructor.
            mm_registry: Multi-modal registry. Not referenced by this
                constructor.
            use_cached_outputs: Whether to use cached outputs. Not referenced
                by this constructor.
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop. Not
                referenced by this constructor: the output handler is started
                here only if a running event loop exists, otherwise on the
                first generate()/encode() call.
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            aggregate_engine_logging: forwarded to StatLoggerManager.
            client_addresses: forwarded to
                EngineCoreClient.make_async_mp_client.
            client_count: forwarded to the engine-core client and to
                StatLoggerManager.
            client_index: forwarded to the engine-core client.

        Returns:
            None
        """
        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.log_requests = log_requests

        # Explicitly passed stat loggers plus any discovered through plugins.
        custom_stat_loggers = list(stat_loggers or [])
        custom_stat_loggers.extend(load_stat_logger_plugin_factories())

        has_custom_loggers = bool(custom_stat_loggers)
        self.log_stats = log_stats or has_custom_loggers
        if not log_stats and has_custom_loggers:
            logger.info(
                "AsyncLLM created with log_stats=False, "
                "but custom stat loggers were found; "
                "enabling logging without default stat loggers."
            )

        if self.model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = init_tokenizer_from_config(self.model_config)

        self.input_processor = InputProcessor(self.vllm_config, tokenizer)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        # `self.tokenizer` reads the tokenizer back from the InputProcessor.
        self.output_processor = OutputProcessor(
            self.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
        )
        endpoint = self.observability_config.otlp_traces_endpoint
        if endpoint is not None:
            tracer = init_tracer("vllm.llm_engine", endpoint)
            self.output_processor.tracer = tracer

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Loggers.
        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=custom_stat_loggers,
                enable_default_loggers=log_stats,
                client_count=client_count,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        # Pause / resume state for async RL workflows.
        self._pause_cond = asyncio.Condition()
        self._paused = False

        self.output_handler: asyncio.Task | None = None
        # Start the output handler eagerly if we are already inside a running
        # event loop; otherwise generate()/encode() start it on first use.
        # Only the loop probe sits inside `try`, so a RuntimeError raised by
        # _run_output_handler() itself is surfaced instead of swallowed.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass
        else:
            self._run_output_handler()

        if (
            envs.VLLM_TORCH_PROFILER_DIR
            and not envs.VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM
        ):
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                envs.VLLM_TORCH_PROFILER_DIR,
            )
            if envs.VLLM_PROFILER_MAX_ITERS > 0 or envs.VLLM_PROFILER_DELAY_ITERS > 0:
                logger.warning_once(
                    "Torch profiler received max_iters or delay_iters setting. These "
                    "are not compatible with the AsyncLLM profiler and will be ignored "
                    "for the AsyncLLM process. Engine process profiling will still "
                    "respect these settings. Consider setting "
                    "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM=1 to disable "
                    "AsyncLLM profiling."
                )
            # One trace file per host/process, e.g. "<host>_<pid>.async_llm".
            worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    envs.VLLM_TORCH_PROFILER_DIR,
                    worker_name=worker_name,
                    use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP,
                ),
            )
        else:
            self.profiler = None
    @property
    @deprecated(
        "`AsyncLLM.processor` has been renamed to `AsyncLLM.input_processor`. "
        "The old name will be removed in v0.13."
    )
    def processor(self):
        """Deprecated alias of ``input_processor``.

        Access goes through ``@deprecated`` (a DeprecationWarning) and yields
        the very same InputProcessor object as ``self.input_processor``.
        """
        current = self.input_processor
        return current
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_log_requests: bool = False,
        aggregate_engine_logging: bool = False,
        disable_log_stats: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> "AsyncLLM":
        """Build an AsyncLLM from an already-constructed VllmConfig.

        The executor class is resolved from the config with
        ``Executor.get_class``. ``enable_log_requests`` becomes the
        constructor's ``log_requests`` and ``disable_log_stats`` is inverted
        into ``log_stats``; every other argument is forwarded unchanged.
        """
        executor_cls = Executor.get_class(vllm_config)
        should_log_stats = not disable_log_stats
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_cls,
            log_stats=should_log_stats,
            usage_context=usage_context,
            log_requests=enable_log_requests,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            aggregate_engine_logging=aggregate_engine_logging,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs.

        The VllmConfig is produced by ``engine_args.create_engine_config``;
        request/stat logging flags are taken from ``engine_args``
        (``disable_log_stats`` is inverted into ``log_stats``).
        """
        config = engine_args.create_engine_config(usage_context)
        executor_cls = Executor.get_class(config)
        return cls(
            vllm_config=config,
            executor_class=executor_cls,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            log_requests=engine_args.enable_log_requests,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
        )
    def __del__(self):
        # Finalizer: delegate to shutdown(), whose getattr() lookups make it
        # tolerate an instance whose __init__ did not complete.
        return self.shutdown()
    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC.

        Shuts down Prometheus, then the engine-core client (if one was
        created), then cancels the output-handler task (if it was started).
        Attributes are looked up with ``getattr`` defaults so this is safe on
        a partially-initialized instance.
        """
        shutdown_prometheus()

        engine_core = getattr(self, "engine_core", None)
        if engine_core:
            engine_core.shutdown()

        output_task = getattr(self, "output_handler", None)
        if output_task is None:
            return
        cancel_task_threadsafe(output_task)
    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Ask the engine core which tasks the loaded model supports."""
        tasks = await self.engine_core.get_supported_tasks_async()
        return tasks
    async def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest | PromptType,
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        prompt_text: str | None = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        ``prompt`` may already be an EngineCoreRequest (used as-is, with the
        caller-provided ``prompt_text``) or a raw prompt that is run through
        the InputProcessor; in the latter case ``prompt_text`` must be None
        and is derived from the prompt itself.

        Pooling requests and sampling requests with ``n == 1`` are submitted
        once; for ``n > 1`` one child request per sample is submitted, all
        sharing the single returned collector.

        Raises:
            EngineDeadError: if the engine has already errored.
        """
        if self.errored:
            raise EngineDeadError()

        pooling = isinstance(params, PoolingParams)

        # One collector receives the outputs of the request (and its children).
        collector = RequestOutputCollector(output_kind=params.output_kind)

        # Convert Input --> Request (pre-built requests pass straight through).
        if not isinstance(prompt, EngineCoreRequest):
            assert prompt_text is None
            request = self.input_processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
                data_parallel_rank,
            )
            if isinstance(prompt, str):
                prompt_text = prompt
            elif isinstance(prompt, Mapping):
                prompt_text = cast(str | None, prompt.get("prompt"))
        else:
            request = prompt

        # process_inputs() may have cloned/updated the params; use the copy
        # carried by the request.
        params = request.params

        if pooling or params.n == 1:
            await self._add_request(request, prompt_text, None, 0, collector)
            return collector

        sampling = params
        assert isinstance(sampling, SamplingParams)

        # Fan out child requests (for n>1).
        parent_request = ParentRequest(request_id, sampling)
        for idx in range(sampling.n):
            child_id, child_params = parent_request.get_child_info(idx)
            # The last child reuses the original request object; earlier
            # children get shallow copies.
            child_request = request if idx == sampling.n - 1 else copy(request)
            child_request.request_id = child_id
            child_request.sampling_params = child_params
            await self._add_request(
                child_request, prompt_text, parent_request, idx, collector
            )
        return collector
    async def _add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        index: int,
        queue: RequestOutputCollector,
    ):
        """Register one request locally, then submit it to the EngineCore.

        The OutputProcessor (this process) is updated before the request is
        sent to the EngineCore (separate process) — presumably so it is
        already tracked when outputs start arriving.
        """
        self.output_processor.add_request(request, prompt, parent_req, index, queue)
        await self.engine_core.add_request_async(request)

        if not self.log_requests:
            return
        logger.info("Added request %s.", request.request_id)
    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: EngineCoreRequest | PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Stream RequestOutputs for one generation request (API server entry).

        Steps:
            * 1) Start the background output handler if it is not running.
            * 2) Wait while generation is paused (see pause_generation()).
            * 3) Process the input and register the request with the
              OutputProcessor (this process) and the EngineCore (separate
              process) via add_request().
            * 4) Drain the request's output collector, yielding each
              RequestOutput until one is marked finished.

        The background output_handler task fills the per-request collector
        with outputs pulled from the EngineCore.

        ``sampling_params.truncate_prompt_tokens`` is validated through
        _validate_truncation_size only when ``tokenization_kwargs`` is None;
        caller-supplied tokenization kwargs are forwarded as given.

        Raises:
            ValueError: bad request, including prompt logprobs requested while
                ``kv_sharing_fast_prefill`` is enabled.
            EngineDeadError: the engine has died.
            EngineGenerateError: any other failure (chained to the cause).
        """
        fast_prefill = self.vllm_config.cache_config.kv_sharing_fast_prefill
        if fast_prefill and sampling_params.prompt_logprobs:
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs"
            )

        try:
            # Started lazily (not in __init__) so construction can happen
            # before the event loop exists, which lets the OpenAI server
            # handle startup failures gracefully.
            self._run_output_handler()

            # Block here while the engine is paused.
            async with self._pause_cond:
                await self._pause_cond.wait_for(lambda: not self._paused)

            if tokenization_kwargs is None:
                tokenization_kwargs = {}
                _validate_truncation_size(
                    self.model_config.max_model_len,
                    sampling_params.truncate_prompt_tokens,
                    tokenization_kwargs,
                )

            collector = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                prompt_text=prompt_text,
            )

            # The output_handler fills the collector; we drain it here.
            while True:
                # get_nowait() first avoids a task switch when output is
                # already available (helps under load).
                out = collector.get_nowait() or await collector.get()
                # OutputProcessor and EngineCore each clean up finished
                # requests on their own.
                done = out.finished
                assert isinstance(out, RequestOutput)
                yield out
                if done:
                    break

        # Client disconnected: generate() was cancelled, or the generator was
        # closed / garbage collected. Abort the request either way.
        except (asyncio.CancelledError, GeneratorExit):
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    def _run_output_handler(self):
        """Background loop: pulls from EngineCore and pushes to AsyncStreams.

        Creates the output-handler task at most once (no-op if it exists).
        Each iteration of the task:
            1) awaits the next batch of EngineCoreOutputs,
            2) feeds it to the OutputProcessor in chunks of at most
               ``envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE`` outputs, yielding to the
               event loop between chunks,
            3) forwards requests that finished on stop strings back to the
               EngineCore for abort,
            4) updates scheduler stats and records metrics.
        Any exception is logged and handed to
        ``OutputProcessor.propagate_error``, after which the task ends.

        Uses ``asyncio.create_task``, so it must run inside an event loop.
        """
        if self.output_handler is not None:
            return

        # Bind only what the task needs: a reference to `self` from inside
        # the coroutine would form a cycle that keeps AsyncLLM alive and
        # prevents proper garbage collection / cleanup.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        logger_manager = self.logger_manager
        input_processor = self.input_processor

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    batch = outputs.outputs
                    num_outputs = len(batch)

                    iteration_stats = None
                    if log_stats and num_outputs:
                        iteration_stats = IterationStats()

                    # Chunk large batches so processing one of them does not
                    # block the event loop for too long.
                    chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
                    if num_outputs > chunk_size:
                        slices = np.array_split(batch, cdiv(num_outputs, chunk_size))
                    else:
                        slices = (batch,)
                    num_slices = len(slices)

                    for i, chunk in enumerate(slices):
                        # 2) Process EngineCoreOutputs. RequestOutputs go
                        # straight into their per-request queues.
                        processed = output_processor.process_outputs(
                            chunk, outputs.timestamp, iteration_stats
                        )
                        assert not processed.request_outputs

                        # Let other asyncio tasks run between chunks.
                        if i + 1 < num_slices:
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        await engine_core.abort_requests_async(
                            processed.reqs_to_abort
                        )

                    output_processor.update_scheduler_stats(outputs.scheduler_stats)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if logger_manager:
                        logger_manager.record(
                            engine_idx=outputs.engine_index,
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                            mm_cache_stats=input_processor.stat_mm_cache(),
                        )
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())
    async def abort(self, request_id: str | Iterable[str]) -> None:
        """Abort RequestId in OutputProcessor and EngineCore.

        Args:
            request_id: a single request id, or an iterable of ids.

        The OutputProcessor is told first; the ids it returns are what get
        aborted in the EngineCore (NOTE(review): presumably this expands
        parent ids to their n>1 child ids — confirm).
        """
        if isinstance(request_id, str):
            request_ids = (request_id,)
        else:
            request_ids = as_list(request_id)

        engine_ids = self.output_processor.abort_requests(request_ids)
        await self.engine_core.abort_requests_async(engine_ids)

        if self.log_requests:
            logger.info("Aborted request(s) %s.", ",".join(request_ids))
    async def pause_generation(
        self,
        *,
        wait_for_inflight_requests: bool = False,
        clear_cache: bool = True,
    ) -> None:
        """
        Pause generation to allow model weight updates.

        New generation/encoding requests are blocked until resume. Calling
        this while already paused returns immediately.

        Args:
            wait_for_inflight_requests: When ``True`` waits for in-flight
                requests to finish before pausing. When ``False`` (default),
                immediately aborts any in-flight requests.
            clear_cache: Whether to clear KV cache and prefix cache after
                draining. Set to ``False`` to preserve cache for faster resume.
                Default is ``True`` (clear caches).
        """
        async with self._pause_cond:
            already_paused = self._paused
            self._paused = True
        if already_paused:
            return

        if not wait_for_inflight_requests:
            inflight = list(self.output_processor.request_states)
            if inflight:
                await self.abort(inflight)

        # Let running requests drain before the caches are touched.
        if self.output_processor.has_unfinished_requests():
            await self.output_processor.wait_for_requests_to_drain()

        if not clear_cache:
            return
        await self.reset_prefix_cache()
        await self.reset_mm_cache()

    async def resume_generation(self) -> None:
        """Resume generation after :meth:`pause_generation`.

        Clears the paused flag and wakes every coroutine waiting on the pause
        condition in generate()/encode().
        """
        cond = self._pause_cond
        async with cond:
            self._paused = False
            cond.notify_all()

    async def is_paused(self) -> bool:
        """Report whether generation is currently paused (read under the lock)."""
        async with self._pause_cond:
            paused = self._paused
        return paused
    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        truncate_prompt_tokens: int | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Main function called by the API server to kick off a pooling request
            * 1) Starting the background output handler if needed.
            * 2) Waiting while generation is paused.
            * 3) Processing the Input and adding the Request to the
              OutputProcessor and the EngineCore (separate process) via
              add_request(), which returns the request's output collector.

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request collector.

        The caller of encode() iterates the returned AsyncGenerator,
        receiving each PoolingRequestOutput until one is finished.

        ``truncate_prompt_tokens`` is always validated through
        _validate_truncation_size against ``tokenization_kwargs`` (an empty
        dict is used when None is passed).

        Raises:
            EngineDeadError: the engine has died.
            ValueError: request validation failed.
            EngineGenerateError: any other failure (chained to the cause).
        """
        try:
            # We start the output_handler on the first call to encode() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()

            # Respect pause state before accepting new requests.
            async with self._pause_cond:
                await self._pause_cond.wait_for(lambda: not self._paused)

            if tokenization_kwargs is None:
                tokenization_kwargs = {}
            _validate_truncation_size(
                self.model_config.max_model_len,
                truncate_prompt_tokens,
                tokenization_kwargs,
            )

            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the client disconnects, encode() is cancelled, or the generator
        # is closed / garbage collected mid-stream (GeneratorExit, a
        # BaseException that the handlers below would not catch). Abort the
        # request in both cases, matching generate().
        except (asyncio.CancelledError, GeneratorExit):
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e
    @property
    def tokenizer(self) -> TokenizerLike | None:
        """Tokenizer held by the InputProcessor (None if none was set)."""
        processor = self.input_processor
        return processor.tokenizer

    @tokenizer.setter
    def tokenizer(self, tokenizer: TokenizerLike | None) -> None:
        # Stored only on the InputProcessor; AsyncLLM keeps no copy of its own.
        # NOTE(review): the OutputProcessor received the tokenizer once in
        # __init__ and is not updated here — confirm that is intended.
        processor = self.input_processor
        processor.tokenizer = tokenizer
    async def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer.

        Raises:
            ValueError: if there is no tokenizer (``skip_tokenizer_init=True``).
        """
        tokenizer = self.tokenizer
        if tokenizer is not None:
            return tokenizer
        raise ValueError(
            "Unable to get tokenizer because `skip_tokenizer_init=True`"
        )
    async def is_tracing_enabled(self) -> bool:
        """True when an OTLP traces endpoint is configured."""
        endpoint = self.observability_config.otlp_traces_endpoint  # type: ignore
        return endpoint is not None
    async def do_log_stats(self) -> None:
        """Ask the stat logger manager to log, if stats logging is enabled."""
        manager = self.logger_manager
        if not manager:
            return
        manager.log()
    async def check_health(self) -> None:
        """Raise the engine's ``dead_error`` if it has errored; else return."""
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error
    async def start_profile(self) -> None:
        """Start profiling in the engine core and, if present, locally.

        When __init__ created a local torch profiler, its ``start`` runs in a
        worker thread concurrently with the engine-core request.
        """
        pending = [self.engine_core.profile_async(True)]
        local_profiler = self.profiler
        if local_profiler is not None:
            pending.append(asyncio.to_thread(local_profiler.start))
        await asyncio.gather(*pending)
    async def stop_profile(self) -> None:
        """Stop profiling in the engine core and, if present, locally.

        When __init__ created a local torch profiler, its ``stop`` runs in a
        worker thread concurrently with the engine-core request.
        """
        pending = [self.engine_core.profile_async(False)]
        local_profiler = self.profiler
        if local_profiler is not None:
            pending.append(asyncio.to_thread(local_profiler.stop))
        await asyncio.gather(*pending)
    async def reset_mm_cache(self) -> None:
        """Clear the multimodal cache, first in the input processor and then
        in the engine core."""
        self.input_processor.clear_mm_cache()
        core = self.engine_core
        await core.reset_mm_cache_async()

    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the engine core's prefix cache and return the core's result.

        Both flags are forwarded positionally to
        ``engine_core.reset_prefix_cache_async``.
        """
        result = await self.engine_core.reset_prefix_cache_async(
            reset_running_requests, reset_connector
        )
        return result

    async def sleep(self, level: int = 1) -> None:
        """Put the engine to sleep at ``level``.

        The prefix cache is reset first. Afterwards, if a stat-logger manager
        exists, it records sleep state ``(1, level)``.
        """
        await self.reset_prefix_cache()
        await self.engine_core.sleep_async(level)
        manager = self.logger_manager
        if manager is None:
            return
        manager.record_sleep_state(1, level)

    async def wake_up(self, tags: list[str] | None = None) -> None:
767
        await self.engine_core.wake_up_async(tags)
768

769
770
771
        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

772
773
774
    async def is_sleeping(self) -> bool:
        """Ask the engine core whether it is currently sleeping."""
        sleeping = await self.engine_core.is_sleeping_async()
        return sleeping

    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests.

        Returns whatever ``engine_core.add_lora_async`` reports.
        """
        core = self.engine_core
        return await core.add_lora_async(lora_request)

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        core = self.engine_core
        return await core.remove_lora_async(lora_id)

    async def list_loras(self) -> set[int]:
        """List all registered adapters."""
        registered = await self.engine_core.list_loras_async()
        return registered

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        pinned = await self.engine_core.pin_lora_async(lora_id)
        return pinned

    async def collective_rpc(
        self,
        method: str,
794
        timeout: float | None = None,
795
        args: tuple = (),
796
        kwargs: dict | None = None,
797
    ):
798
799
800
801
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
802
803
            method, timeout, args, kwargs
        )
804

805
806
807
808
809
810
811
812
    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """Wait for all requests to be drained.

        Polls ``engine_core.dp_engines_running()`` roughly once per second
        and returns as soon as it reports that no engines are running.

        Args:
            drain_timeout: Maximum time to wait, in seconds.

        Raises:
            TimeoutError: If engines are still running after
                ``drain_timeout`` seconds.
        """
        # Measure elapsed time on a monotonic clock. Wall-clock time
        # (time.time()) can jump under system/NTP clock adjustments, which
        # would stretch or cut short the drain window.
        deadline = time.monotonic() + drain_timeout
        while time.monotonic() < deadline:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

            logger.info("Engines are still running, waiting for requests to drain...")
            await asyncio.sleep(1)  # Wait 1 second before checking again

        raise TimeoutError(
            f"Timeout reached after {drain_timeout} seconds "
            "waiting for requests to drain."
        )

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ):
        """
        Scale up or down the data parallel size by adding or removing
        engine cores.

        Waits for in-flight requests to drain, asks the engine core to
        rescale, then updates ``parallel_config.data_parallel_size``. When
        scaling up with stats logging enabled, the stat-logger manager is
        recreated for the new range of engine indices.

        Args:
            new_data_parallel_size: The new number of data parallel workers
            drain_timeout:
                Maximum time to wait for requests to drain (seconds)

        Raises:
            TimeoutError: Propagated from ``wait_for_requests_to_drain`` if
                requests do not drain within ``drain_timeout``.
        """
        old_data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if old_data_parallel_size == new_data_parallel_size:
            logger.info(
                "Data parallel size is already %s, skipping scale",
                new_data_parallel_size,
            )
            return
        scaling_up = new_data_parallel_size > old_data_parallel_size
        # Report the actual direction: shrinking the DP group is as valid as
        # growing it, so the message must not unconditionally say "up".
        logger.info(
            "Waiting for requests to drain before scaling %s to %s engines...",
            "up" if scaling_up else "down",
            new_data_parallel_size,
        )
        await self.wait_for_requests_to_drain(drain_timeout)
        logger.info(
            "Requests have been drained, proceeding with scale to %s engines",
            new_data_parallel_size,
        )
        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size

        # recreate stat loggers
        # NOTE(review): when scaling down, the existing manager keeps its old
        # engine indices — confirm that is intended.
        if scaling_up and self.log_stats:
            # TODO(rob): fix this after talking with Ray team.
            # This resets all the prometheus metrics since we
            # unregister during initialization. Need to understand
            # the intended behavior here better.
            self.logger_manager = StatLoggerManager(
                vllm_config=self.vllm_config,
                engine_idxs=list(range(new_data_parallel_size)),
                custom_stat_loggers=None,
            )

    @property
    def is_running(self) -> bool:
        """``True`` until the output handler reports that it is done.

        ``output_handler`` is ``None`` before the loop is started, which also
        counts as running.
        """
        handler = self.output_handler
        if handler is None:
            return True
        return not handler.done()

    @property
    def is_stopped(self) -> bool:
        """Same value as ``errored``."""
        stopped = self.errored
        return stopped

    @property
    def errored(self) -> bool:
        """Truthy when the engine core's resources mark the engine dead, or
        when ``is_running`` is false."""
        dead = self.engine_core.resources.engine_dead
        if dead:
            return dead
        return not self.is_running

    @property
    def dead_error(self) -> BaseException:
        """A new ``EngineDeadError`` instance, created on every access."""
        error: BaseException = EngineDeadError()
        return error