async_llm.py 38.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import os
import socket
6
import time
7
import warnings
8
from collections.abc import AsyncGenerator, Iterable, Mapping
9
from copy import copy
10
from typing import Any
11

12
import torch
13

14
import vllm.envs as envs
15
from vllm import TokensPrompt
16
from vllm.config import VllmConfig
17
from vllm.engine.arg_utils import AsyncEngineArgs
18
from vllm.engine.protocol import EngineClient
19
from vllm.entrypoints.utils import _validate_truncation_size
20
from vllm.inputs import PromptType
21
from vllm.inputs.data import StreamingInput
22
23
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
24
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
25
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
26
from vllm.plugins.io_processors import get_io_processor
27
from vllm.pooling_params import PoolingParams
28
from vllm.renderers import RendererLike
29
from vllm.sampling_params import RequestOutputKind, SamplingParams
30
from vllm.tasks import SupportedTask
31
from vllm.tokenizers import TokenizerLike
32
from vllm.tracing import init_tracer
33
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
34
from vllm.usage.usage_lib import UsageContext
35
36
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
37
from vllm.v1.engine import EngineCoreRequest
38
from vllm.v1.engine.core_client import EngineCoreClient
39
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
40
from vllm.v1.engine.input_processor import InputProcessor
41
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
42
from vllm.v1.engine.parallel_sampling import ParentRequest
43
from vllm.v1.engine.utils import get_prompt_text
44
from vllm.v1.executor import Executor
45
46
47
48
49
from vllm.v1.metrics.loggers import (
    StatLoggerFactory,
    StatLoggerManager,
    load_stat_logger_plugin_factories,
)
50
from vllm.v1.metrics.prometheus import shutdown_prometheus
51
from vllm.v1.metrics.stats import IterationStats
52
53
54
55

# Module-level logger used by the classes and helpers in this file.
logger = init_logger(__name__)


56
57
58
59
60
61
62
63
64
65
66
67
class InputStreamError(Exception):
    """Carrier for an exception raised by a caller-supplied input stream.

    Errors coming out of the user's streaming-input generator are wrapped in
    this type so that ``AsyncLLM.generate()`` can tell them apart from engine
    failures and re-raise the original exception (``cause``) instead of
    wrapping it in ``EngineGenerateError``.

    Attributes:
        cause: the exception originally raised by the input stream.
    """

    def __init__(self, cause: Exception):
        # The message mirrors the wrapped exception so logs stay readable.
        super().__init__(str(cause))
        self.cause = cause


class AsyncLLM(EngineClient):
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: list[StatLoggerFactory] | None = None,
        aggregate_engine_logging: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats. Stats collection is also turned
                on when custom or plugin stat loggers are found; the default
                stat loggers are enabled only when this is True.
            usage_context: Usage context of the LLM.
            mm_registry: Multi-modal registry.
            use_cached_outputs: Whether to use cached outputs.
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop.
            stat_loggers: customized stat logger factories for the engine.
                They are combined with any plugin-provided factories.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            aggregate_engine_logging: forwarded to StatLoggerManager.
            client_addresses: forwarded to the EngineCore client factory.
            client_count: forwarded to the EngineCore client and to
                StatLoggerManager (presumably the number of front-end
                clients sharing the engine — confirm).
            client_index: forwarded to the EngineCore client (presumably
                this client's position among ``client_count`` — confirm).

        NOTE(review): usage_context, mm_registry, use_cached_outputs and
        start_engine_loop are accepted but not read by this constructor.

        The background output handler is started here only when an asyncio
        event loop is already running; otherwise add_request() starts it.

        Returns:
            None
        """
        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.log_requests = log_requests

        custom_stat_loggers = list(stat_loggers or [])
        custom_stat_loggers.extend(load_stat_logger_plugin_factories())

        # Custom loggers force stats collection on even if log_stats=False.
        has_custom_loggers = bool(custom_stat_loggers)
        self.log_stats = log_stats or has_custom_loggers
        if not log_stats and has_custom_loggers:
            logger.info(
                "AsyncLLM created with log_stats=False, "
                "but custom stat loggers were found; "
                "enabling logging without default stat loggers."
            )

        self.input_processor = InputProcessor(self.vllm_config)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(
            self.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
        )
        endpoint = self.observability_config.otlp_traces_endpoint
        if endpoint is not None:
            tracer = init_tracer("vllm.llm_engine", endpoint)
            self.output_processor.tracer = tracer

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Loggers.
        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=custom_stat_loggers,
                enable_default_loggers=log_stats,
                client_count=client_count,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        # Pause / resume state for async RL workflows.
        self._pause_cond = asyncio.Condition()
        self._paused = False

        self.output_handler: asyncio.Task | None = None
        try:
            # Only probe for a running loop here. Keeping the handler start
            # out of the try body ensures a RuntimeError raised while starting
            # it is not mistaken for "no running loop" and silently dropped.
            asyncio.get_running_loop()
        except RuntimeError:
            # No running event loop yet (e.g. constructed before the server's
            # loop starts); add_request() starts the handler lazily instead.
            pass
        else:
            # Start output handler eagerly if we are in the asyncio eventloop.
            self._run_output_handler()

        if (
            vllm_config.profiler_config.profiler == "torch"
            and not vllm_config.profiler_config.ignore_frontend
        ):
            profiler_dir = vllm_config.profiler_config.torch_profiler_dir
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                profiler_dir,
            )
            worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=vllm_config.profiler_config.torch_profiler_with_stack,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    profiler_dir,
                    worker_name=worker_name,
                    use_gzip=vllm_config.profiler_config.torch_profiler_use_gzip,
                ),
            )
        else:
            self.profiler = None

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_log_requests: bool = False,
        aggregate_engine_logging: bool = False,
        disable_log_stats: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> "AsyncLLM":
        """Build an AsyncLLM from an already-resolved ``VllmConfig``.

        The executor class is picked from ``vllm_config`` via
        ``Executor.get_class``. ``disable_log_stats`` is inverted into the
        constructor's ``log_stats`` flag and ``enable_log_requests`` is
        passed through as ``log_requests``; every other argument is
        forwarded unchanged.
        """
        executor_class = Executor.get_class(vllm_config)
        log_stats = not disable_log_stats
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=log_stats,
            usage_context=usage_context,
            log_requests=enable_log_requests,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            aggregate_engine_logging=aggregate_engine_logging,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs.

        The engine config is materialised with
        ``engine_args.create_engine_config(usage_context)``. The executor
        class is derived from that config, and the stats / request-logging
        flags are taken from ``engine_args``.
        """
        config = engine_args.create_engine_config(usage_context)
        return cls(
            vllm_config=config,
            executor_class=Executor.get_class(config),
            log_stats=not engine_args.disable_log_stats,
            log_requests=engine_args.enable_log_requests,
            usage_context=usage_context,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
        )

    def __del__(self):
        """Finalizer: release resources by delegating to ``shutdown()``.

        ``shutdown()`` looks up ``engine_core``, ``input_processor`` and
        ``output_handler`` with ``getattr`` defaults, so those steps are
        skipped if ``__init__`` did not get far enough to set them.
        """
        self.shutdown()

    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC.

        Steps, in order:
          1. shut down the Prometheus metrics integration;
          2. shut down the EngineCore client, if one was created;
          3. close the input processor, if one was created;
          4. cancel the background output-handler task, if it was started.

        Attributes are read with ``getattr`` defaults so this is safe to call
        on a partially constructed instance (e.g. from ``__del__``).
        """
        shutdown_prometheus()

        core_client = getattr(self, "engine_core", None)
        if core_client:
            core_client.shutdown()

        processor = getattr(self, "input_processor", None)
        if processor:
            processor.close()

        task = getattr(self, "output_handler", None)
        if task is not None:
            cancel_task_threadsafe(task)

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks reported as supported by the EngineCore client."""
        supported = await self.engine_core.get_supported_tasks_async()
        return supported

    async def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        prompt_text: str | None = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        Flow:
          * validate ``params`` (fast-prefill/logprobs and truncation size);
          * delegate streaming prompts (an ``AsyncGenerator``) to
            ``_add_streaming_input_request`` — that path returns before the
            pause check below;
          * otherwise turn ``prompt`` into an EngineCoreRequest (or use the
            one passed in), assign its internal id, make sure the output
            handler runs, and wait while generation is paused;
          * register the request, fanning it out into ``n`` child requests
            that share one collector when sampling with ``n > 1``.

        Returns:
            The RequestOutputCollector that receives this request's outputs.

        Raises:
            EngineDeadError: if the engine has already errored.
            ValueError: on invalid parameter combinations, or when
                ``prompt_text`` is given without an EngineCoreRequest prompt.
        """
        if self.errored:
            raise EngineDeadError()

        is_pooling = isinstance(params, PoolingParams)

        fast_prefill = self.vllm_config.cache_config.kv_sharing_fast_prefill
        if fast_prefill and not is_pooling and params.prompt_logprobs:
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs"
            )

        tokenization_kwargs = {} if tokenization_kwargs is None else tokenization_kwargs
        _validate_truncation_size(
            self.model_config.max_model_len,
            params.truncate_prompt_tokens,
            tokenization_kwargs,
        )

        if isinstance(prompt, AsyncGenerator):
            # Streaming input: handled entirely by the dedicated helper.
            return await self._add_streaming_input_request(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
                data_parallel_rank,
            )

        # Convert Input --> Request (unless a ready-made request was given).
        if not isinstance(prompt, EngineCoreRequest):
            if prompt_text is not None:
                raise ValueError(
                    "should only provide prompt_text with EngineCoreRequest"
                )
            request = self.input_processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
                data_parallel_rank,
            )
            prompt_text = get_prompt_text(prompt)
        else:
            request = prompt
            if request_id != request.request_id:
                logger.warning_once(
                    "AsyncLLM.add_request() was passed a request_id parameter that "
                    "does not match the EngineCoreRequest.request_id attribute. The "
                    "latter will be used, and the former will be ignored."
                )

        self.input_processor.assign_request_id(request)

        # The output handler is started lazily on the first add_request() so
        # that __init__ can run before the event loop exists, which lets the
        # OpenAI server handle startup failures gracefully.
        self._run_output_handler()

        # Block here while generation is paused.
        async with self._pause_cond:
            await self._pause_cond.wait_for(lambda: not self._paused)

        collector = RequestOutputCollector(params.output_kind, request.request_id)

        # process_inputs() may have cloned/updated the params; use the
        # request's copy from here on.
        params = request.params

        if is_pooling or params.n == 1:
            await self._add_request(request, prompt_text, None, 0, collector)
            return collector

        assert isinstance(params, SamplingParams)

        # Fan out n child requests; the original request object is reused
        # for the last child and shallow-copied for the others.
        parent_request = ParentRequest(request)
        for idx in range(params.n):
            child_id, child_params = parent_request.get_child_info(idx)
            child_request = request if idx == params.n - 1 else copy(request)
            child_request.request_id = child_id
            child_request.sampling_params = child_params
            await self._add_request(
                child_request, prompt_text, parent_request, idx, collector
            )
        return collector

    async def _add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        index: int,
        queue: RequestOutputCollector,
    ):
        """Register a single EngineCoreRequest locally and with the EngineCore.

        The OutputProcessor entry (this process) is created before the
        request is sent to the EngineCore (separate process). The ordering
        presumably ensures outputs always find a registered request state;
        keep it.

        Args:
            request: the engine-core request to submit.
            prompt: prompt text recorded with the OutputProcessor, if known.
            parent_req: parent for ``n > 1`` fan-out children, else ``None``.
            index: child index within the parent (``0`` when not fanned out).
            queue: collector that will receive this request's outputs.
        """
        self.output_processor.add_request(request, prompt, parent_req, index, queue)
        await self.engine_core.add_request_async(request)

        if not self.log_requests:
            return
        logger.info("Added request %s.", request.request_id)

    async def _add_streaming_input_request(
        self,
        request_id: str,
        input_stream: AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> RequestOutputCollector:
        """Register a request whose prompt arrives incrementally.

        Each ``StreamingInput`` chunk pulled from ``input_stream`` becomes a
        resumable EngineCoreRequest. All of these requests share one internal
        request id and one output collector. The stream is consumed by a
        background task stored on ``queue._input_stream_task``, so this
        coroutine returns the collector without waiting for the stream to end.

        Args:
            request_id: external request id supplied by the caller.
            input_stream: async generator of prompt chunks. A chunk may carry
                its own ``sampling_params``; otherwise the request-level
                ``sampling_params`` are used.
            sampling_params: request-level params, validated by
                ``_validate_streaming_input_sampling_params``.
            arrival_time, lora_request, tokenization_kwargs, trace_headers,
            priority, data_parallel_rank: forwarded to every
                ``process_inputs`` call.

        Returns:
            The RequestOutputCollector that receives the request's outputs.

        Raises:
            ValueError: if ``sampling_params`` are not supported for streaming
                input. This is raised before the background task is created.
                Errors from the stream itself are delivered through the
                collector as ``InputStreamError``.
        """
        self._validate_streaming_input_sampling_params(sampling_params)

        # Keyword arguments shared by every process_inputs() call below (the
        # placeholder final request and each streamed chunk).
        inputs = dict(
            arrival_time=arrival_time,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            trace_headers=trace_headers,
            priority=priority,
            data_parallel_rank=data_parallel_rank,
        )

        # Clone once up front and mark skip_clone. Presumably this makes
        # process_inputs() reuse this object rather than re-cloning it for
        # every chunk — TODO confirm against InputProcessor.
        if not sampling_params.skip_clone:
            sampling_params = sampling_params.clone()
            sampling_params.skip_clone = True

        # Create request for validation, also used as the finished signal
        # once the input stream is closed.
        # Its prompt is a single placeholder token ([0]), not user input.
        final_req = self.input_processor.process_inputs(
            request_id=request_id,
            prompt=TokensPrompt(prompt_token_ids=[0]),
            params=sampling_params,
            **inputs,  # type: ignore[arg-type]
        )
        self.input_processor.assign_request_id(final_req)
        # Every chunk request below reuses this internal id.
        internal_req_id = final_req.request_id

        queue = RequestOutputCollector(sampling_params.output_kind, internal_req_id)

        async def handle_inputs():
            # Background consumer: turns each StreamingInput chunk into a
            # resumable EngineCoreRequest and submits it.
            cancelled = False
            try:
                async for input_chunk in input_stream:
                    sp = input_chunk.sampling_params
                    # Per-chunk params replace the request-level ones but must
                    # pass the same streaming validation.
                    if sp:
                        self._validate_streaming_input_sampling_params(sp)
                    else:
                        sp = sampling_params
                    req = self.input_processor.process_inputs(
                        request_id=internal_req_id,
                        prompt=input_chunk.prompt,
                        params=sp,
                        resumable=True,
                        **inputs,  # type: ignore[arg-type]
                    )
                    # Record the caller's id alongside the shared internal id.
                    req.external_req_id = request_id
                    if req.prompt_embeds is not None:
                        raise ValueError(
                            "prompt_embeds not supported for streaming inputs"
                        )
                    prompt_text = get_prompt_text(input_chunk.prompt)
                    # NOTE(review): unlike the non-streaming path of
                    # add_request(), chunks are submitted here without waiting
                    # on self._pause_cond.
                    await self._add_request(req, prompt_text, None, 0, queue)
            except (asyncio.CancelledError, GeneratorExit):
                # Not re-raised: the flag only suppresses the final request
                # in the `finally` block below.
                cancelled = True
            except Exception as error:
                # Wrap in InputStreamError so generate() can propagate it
                # without wrapping in EngineGenerateError.
                queue.put(InputStreamError(error))
            finally:
                # Drop the collector's handle to this task.
                queue._input_stream_task = None
                if not cancelled:
                    # Send empty final request to indicate that inputs have
                    # finished. Don't send if cancelled (session was aborted).
                    # NOTE(review): this also runs after an input-stream error.
                    # An exception raised here escapes the background task,
                    # and nothing in this method awaits that task.
                    await self._add_request(final_req, None, None, 0, queue)

        # Ensure output handler is running.
        self._run_output_handler()

        queue._input_stream_task = asyncio.create_task(handle_inputs())
        return queue

    @staticmethod
    def _validate_streaming_input_sampling_params(
        params: SamplingParams | PoolingParams,
    ):
        """Reject parameter sets that streaming input cannot handle.

        Streaming input is only accepted for ``SamplingParams`` with
        ``n <= 1``, an output kind other than ``FINAL_ONLY``, and no stop
        strings.

        Raises:
            ValueError: if any of those conditions does not hold.
        """
        # Short-circuit order matters: .n / .output_kind / .stop are only
        # read once the params are known to be SamplingParams.
        supported = (
            isinstance(params, SamplingParams)
            and not params.n > 1
            and params.output_kind != RequestOutputKind.FINAL_ONLY
            and not params.stop
        )
        if not supported:
            raise ValueError(
                "Input streaming not currently supported "
                "for pooling models, n > 1, request_kind = FINAL_ONLY "
                "or with stop strings."
            )

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Making an output collector corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the OutputProcessor.
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request collector.

        The caller of generate() iterates the returned AsyncGenerator,
        returning the RequestOutput back to the caller.

        Raises:
            asyncio.CancelledError / GeneratorExit: re-raised after aborting
                the request (client disconnect or generator closed).
            EngineDeadError: re-raised; the request is not aborted.
            ValueError: re-raised (request validation error).
            Exception: the original error from a streaming input generator
                is re-raised as-is (unwrapped from InputStreamError).
            EngineGenerateError: wraps any other exception, after aborting
                the request.
        """

        q: RequestOutputCollector | None = None
        try:
            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                prompt_text=prompt_text,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                assert isinstance(out, RequestOutput)
                finished = out.finished
                # STREAM_FINISHED is an end-of-stream marker, not a result.
                if out is not STREAM_FINISHED:
                    yield out

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError as e:
            if self.log_requests:
                logger.info("Request %s failed (bad request): %s.", request_id, e)
            raise

        # Error from input stream generator - propagate directly.
        except InputStreamError as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed (input error): %s.", request_id, e)
            raise e.cause from e

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                # Formatting the exception may itself raise; fall back to
                # reporting only the class names in that case.
                try:
                    s = f"{e.__class__.__name__}: {e}"
                except Exception as e2:
                    s = (
                        f"{e.__class__.__name__}: "
                        "error during printing an exception of class "
                        + e2.__class__.__name__
                    )
                logger.info("Request %s failed due to %s.", request_id, s)
            raise EngineGenerateError() from e

        finally:
            if q is not None:
                q.close()

    def _run_output_handler(self):
        """Background loop: pulls from EngineCore and pushes to AsyncStreams.

        Creates the output-handler task once; later calls are no-ops. Each
        iteration of the task:
          1. awaits the next batch of EngineCoreOutputs from the EngineCore;
          2. feeds it to the OutputProcessor in slices of at most
             ``VLLM_V1_OUTPUT_PROC_CHUNK_SIZE`` (RequestOutputs are pushed
             onto their per-request collectors by the processor);
          3. aborts requests the processor reports as finished by stop
             strings;
          4. updates scheduler stats and records metrics.
        An exception ends the loop: it is logged and handed to
        ``OutputProcessor.propagate_error``.
        """
        if self.output_handler is not None:
            return

        # Copy everything the task needs into locals so the coroutine does
        # not capture ``self``. A reference cycle through the running task
        # would keep the AsyncLLM alive and block proper cleanup.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        logger_manager = self.logger_manager
        input_processor = self.input_processor
        chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE

        async def output_handler():
            try:
                while True:
                    # 1) Next batch from the EngineCore.
                    batch = await engine_core.get_output_async()
                    core_outputs = batch.outputs
                    total = len(core_outputs)

                    stats = IterationStats() if (log_stats and total) else None

                    # Work through the batch in bounded slices so a large
                    # batch does not monopolise the event loop.
                    for begin in range(0, total, chunk_size):
                        stop_idx = begin + chunk_size

                        # 2) Convert EngineCoreOutputs for this slice.
                        processed = output_processor.process_outputs(
                            core_outputs[begin:stop_idx], batch.timestamp, stats
                        )
                        # RequestOutputs go straight onto their queues.
                        assert not processed.request_outputs

                        # Give other asyncio tasks a turn between slices.
                        if stop_idx < total:
                            await asyncio.sleep(0)

                        # 3) Abort requests that finished due to stop strings.
                        if processed.reqs_to_abort:
                            await engine_core.abort_requests_async(
                                processed.reqs_to_abort
                            )

                    output_processor.update_scheduler_stats(batch.scheduler_stats)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if logger_manager:
                        logger_manager.record(
                            engine_idx=batch.engine_index,
                            scheduler_stats=batch.scheduler_stats,
                            iteration_stats=stats,
                            mm_cache_stats=input_processor.stat_mm_cache(),
                        )
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())

    async def abort(
        self, request_id: str | Iterable[str], internal: bool = False
    ) -> None:
        """Abort RequestId in OutputProcessor and EngineCore.

        Args:
            request_id: a single request id or an iterable of ids.
            internal: forwarded to ``OutputProcessor.abort_requests``. The
                callers in this file pass ``True`` together with ids taken
                from collectors / request states (presumably internal ids).

        The OutputProcessor returns the full set of ids to abort (presumably
        including fan-out children); that set is what the EngineCore is
        asked to abort. The log line lists the ids as passed in.
        """
        if isinstance(request_id, str):
            ids = (request_id,)
        else:
            ids = as_list(request_id)

        expanded_ids = self.output_processor.abort_requests(ids, internal)
        await self.engine_core.abort_requests_async(expanded_ids)

        if self.log_requests:
            logger.info("Aborted request(s) %s.", ",".join(ids))

    async def pause_generation(
        self,
        *,
        wait_for_inflight_requests: bool = False,
        clear_cache: bool = True,
    ) -> None:
        """
        Pause generation to allow model weight updates.

        New generation/encoding requests are blocked until resume. Calling
        this while already paused returns immediately.

        Args:
            wait_for_inflight_requests: When ``True`` waits for in-flight
                requests to finish before pausing. When ``False`` (default),
                immediately aborts any in-flight requests.
            clear_cache: Whether to clear KV cache and prefix cache after
                draining. Set to ``False`` to preserve cache for faster resume.
                Default is ``True`` (clear caches).
        """
        async with self._pause_cond:
            already_paused = self._paused
            self._paused = True
        if already_paused:
            return

        if not wait_for_inflight_requests:
            in_flight = list(self.output_processor.request_states)
            if in_flight:
                await self.abort(in_flight, internal=True)

        # Let whatever is still running finish before touching the caches.
        if self.output_processor.has_unfinished_requests():
            await self.output_processor.wait_for_requests_to_drain()

        if not clear_cache:
            return
        await self.reset_prefix_cache()
        await self.reset_mm_cache()

    async def resume_generation(self) -> None:
        """Lift a pause set by :meth:`pause_generation`.

        Clears the paused flag and notifies every coroutine waiting on the
        pause condition so that held-back requests can proceed.
        """
        cond = self._pause_cond
        async with cond:
            self._paused = False
            cond.notify_all()

    async def is_paused(self) -> bool:
        """Report whether generation is currently paused.

        The flag is read while holding the pause condition's lock.
        """
        async with self._pause_cond:
            paused = self._paused
        return paused

752
    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        truncate_prompt_tokens: int | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Main function called by the API server to kick off a pooling request.

        The request is registered through :meth:`add_request`, which returns
        the per-request ``RequestOutputCollector``. A separate output_handler
        loop runs in a background AsyncIO task, pulling outputs from
        EngineCore and putting them into that collector.

        The caller of encode() iterates the returned AsyncGenerator and
        receives ``PoolingRequestOutput`` objects until one is marked
        ``finished``.

        Args:
            prompt: The input prompt.
            pooling_params: Pooling parameters for this request.
            request_id: Identifier of the request (also used in log messages).
            lora_request: Optional LoRA adapter, forwarded to add_request.
            trace_headers: Optional tracing headers, forwarded to add_request.
            priority: Request priority, forwarded to add_request.
            truncate_prompt_tokens: Deprecated; use
                ``pooling_params.truncate_prompt_tokens`` instead. If given
                and ``pooling_params`` does not already set a value, it is
                applied to a copy of ``pooling_params``.
            tokenization_kwargs: Extra tokenization options, forwarded to
                add_request.

        Raises:
            asyncio.CancelledError: Re-raised after aborting the request
                (e.g. when the client disconnects).
            EngineDeadError: Re-raised without aborting (engine is dead).
            ValueError: Re-raised for request validation errors.
            EngineGenerateError: Wraps any other exception, raised after the
                request has been aborted.

        NOTE: truncate_prompt_tokens is deprecated in v0.14.
        TODO: Remove truncate_prompt_tokens in v0.15.
        """

        q: RequestOutputCollector | None = None
        try:
            if truncate_prompt_tokens is not None:
                warnings.warn(
                    "The `truncate_prompt_tokens` parameter in `AsyncLLM.encode()` "
                    "is deprecated and will be removed in v0.15. "
                    "Please use `pooling_params.truncate_prompt_tokens` instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )
                # Without this, the deprecated argument would only trigger
                # the warning above and otherwise have no effect. Forward it,
                # but let an explicit pooling_params value win and never
                # mutate the caller's pooling_params object.
                if pooling_params.truncate_prompt_tokens is None:
                    pooling_params = copy(pooling_params)
                    pooling_params.truncate_prompt_tokens = truncate_prompt_tokens

            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, encode()
        # is cancelled. So, we abort the request if we end up here.
        except asyncio.CancelledError:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

        finally:
            if q is not None:
                q.close()
845

846
    @property
    def tokenizer(self) -> TokenizerLike | None:
        """Tokenizer exposed by the input processor; may be ``None``."""
        processor = self.input_processor
        return processor.tokenizer
849

850
851
    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer obtained from ``input_processor.get_tokenizer()``."""
        processor = self.input_processor
        return processor.get_tokenizer()
852

853
854
855
    @property
    def renderer(self) -> RendererLike:
        """Renderer exposed by the input processor."""
        processor = self.input_processor
        return processor.renderer
856
857

    async def is_tracing_enabled(self) -> bool:
        """Return ``True`` when an OTLP traces endpoint is configured."""
        endpoint = self.observability_config.otlp_traces_endpoint  # type: ignore
        return endpoint is not None
859

860
    async def do_log_stats(self) -> None:
        """Emit stats through ``self.logger_manager`` when one is present.

        ``self.logger_manager`` may be ``None``. The explicit ``is not None``
        check matches how ``sleep``/``wake_up`` test it, and it does not skip
        a manager object that happens to be falsy.
        """
        if self.logger_manager is not None:
            self.logger_manager.log()
863
864
865

    async def check_health(self) -> None:
        """Raise ``self.dead_error`` if the engine reports an errored state."""
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error
868
869

    async def start_profile(self) -> None:
        """Start profiling in the engine core and, if configured, locally.

        The local profiler's ``start`` is run via ``asyncio.to_thread`` so it
        executes off the event loop; both are awaited together.
        """
        local_profiler = self.profiler
        pending = [self.engine_core.profile_async(True)]
        if local_profiler is not None:
            pending.append(asyncio.to_thread(local_profiler.start))
        await asyncio.gather(*pending)
874
875

    async def stop_profile(self) -> None:
        """Stop profiling in the engine core and, if configured, locally.

        The local profiler's ``stop`` is run via ``asyncio.to_thread`` so it
        executes off the event loop; both are awaited together.
        """
        local_profiler = self.profiler
        pending = [self.engine_core.profile_async(False)]
        if local_profiler is not None:
            pending.append(asyncio.to_thread(local_profiler.stop))
        await asyncio.gather(*pending)
880

881
    async def reset_mm_cache(self) -> None:
        """Clear multimodal caches: first the input processor's, then the
        engine core's."""
        processor, core = self.input_processor, self.engine_core
        processor.clear_mm_cache()
        await core.reset_mm_cache_async()

885
886
887
888
889
890
    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Ask the engine core to reset its prefix cache.

        Both flags are forwarded positionally to
        ``engine_core.reset_prefix_cache_async``, whose result is returned.
        """
        forwarded = (reset_running_requests, reset_connector)
        outcome = await self.engine_core.reset_prefix_cache_async(*forwarded)
        return outcome
891

892
    async def sleep(self, level: int = 1) -> None:
        """Put the engine to sleep at ``level``.

        The prefix cache is reset first. Afterwards, if a logger manager is
        present, the sleep state ``(1, level)`` is recorded on it.
        """
        await self.reset_prefix_cache()
        await self.engine_core.sleep_async(level)

        manager = self.logger_manager
        if manager is None:
            return
        manager.record_sleep_state(1, level)

899
    async def wake_up(self, tags: list[str] | None = None) -> None:
900
        await self.engine_core.wake_up_async(tags)
901

902
903
904
        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

905
906
907
    async def is_sleeping(self) -> bool:
        """Ask the engine core whether it is currently sleeping."""
        sleeping = await self.engine_core.is_sleeping_async()
        return sleeping

908
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests.

        Delegates to ``engine_core.add_lora_async`` and returns its result.
        """
        loaded = await self.engine_core.add_lora_async(lora_request)
        return loaded

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter, identified by ``lora_id``."""
        removed = await self.engine_core.remove_lora_async(lora_id)
        return removed

916
    async def list_loras(self) -> set[int]:
        """List all registered adapters, as reported by the engine core."""
        registered = await self.engine_core.list_loras_async()
        return registered

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent the adapter ``lora_id`` from being evicted."""
        pinned = await self.engine_core.pin_lora_async(lora_id)
        return pinned
923

924
925
926
    async def collective_rpc(
        self,
        method: str,
927
        timeout: float | None = None,
928
        args: tuple = (),
929
        kwargs: dict | None = None,
930
    ):
931
932
933
934
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
935
936
            method, timeout, args, kwargs
        )
937

938
939
940
941
942
943
944
945
    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """Wait until ``engine_core.dp_engines_running()`` reports idle.

        Polls roughly once per second. The engines are always checked at
        least once, including right before giving up. As a result, an
        already-idle engine (or one that went idle during the last sleep) is
        never reported as a timeout.

        Args:
            drain_timeout: Maximum number of seconds to wait.

        Raises:
            TimeoutError: If the engines are still running once
                ``drain_timeout`` seconds have elapsed.
        """
        # time.monotonic() is unaffected by wall-clock adjustments, which
        # could otherwise shorten or stretch the timeout.
        deadline = time.monotonic() + drain_timeout
        while True:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise TimeoutError(
                    f"Timeout reached after {drain_timeout} seconds "
                    "waiting for requests to drain."
                )

            logger.info("Engines are still running, waiting for requests to drain...")
            # Wait up to 1 second before checking again, never past the deadline.
            await asyncio.sleep(min(1.0, remaining))
953

954
955
956
    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ) -> None:
        """
        Scale the data parallel size up or down by adding or removing
        engine cores.

        Waits for in-flight requests to drain, asks the engine core to
        rescale, then records the new size in
        ``vllm_config.parallel_config``. When scaling up with stat logging
        enabled, the stat logger manager is rebuilt for the new engine
        indices. If the size is unchanged, this only logs and returns.

        Args:
            new_data_parallel_size: The new number of data parallel workers.
            drain_timeout: Maximum time to wait for requests to drain
                (seconds).

        Raises:
            TimeoutError: If requests do not drain within ``drain_timeout``
                (raised by ``wait_for_requests_to_drain``).
        """
        old_data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if old_data_parallel_size == new_data_parallel_size:
            logger.info(
                "Data parallel size is already %s, skipping scale",
                new_data_parallel_size,
            )
            return
        # Direction-neutral wording: this path handles scaling down too.
        logger.info(
            "Waiting for requests to drain before scaling from %s to %s engines...",
            old_data_parallel_size,
            new_data_parallel_size,
        )
        await self.wait_for_requests_to_drain(drain_timeout)
        logger.info(
            "Requests have been drained, proceeding with scale to %s engines",
            new_data_parallel_size,
        )
        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size

        # recreate stat loggers
        if new_data_parallel_size > old_data_parallel_size and self.log_stats:
            # TODO(rob): fix this after talking with Ray team.
            # This resets all the prometheus metrics since we
            # unregister during initialization. Need to understand
            # the intended behavior here better.
            # NOTE(review): custom_stat_loggers=None presumably drops any
            # custom stat loggers configured at startup — confirm intended.
            self.logger_manager = StatLoggerManager(
                vllm_config=self.vllm_config,
                engine_idxs=list(range(new_data_parallel_size)),
                custom_stat_loggers=None,
            )

996
997
    @property
    def is_running(self) -> bool:
        """Whether the background output handler is still alive.

        ``output_handler`` is ``None`` until the loop has been started, and
        that state counts as running.
        """
        handler = self.output_handler
        if handler is None:
            return True
        return not handler.done()
1000
1001
1002

    @property
    def is_stopped(self) -> bool:
        """Same value as ``errored``: stopped means the engine has errored."""
        stopped = self.errored
        return stopped
1004
1005
1006

    @property
    def errored(self) -> bool:
        """True when the engine core is dead or the output handler stopped."""
        dead = self.engine_core.resources.engine_dead
        if dead:
            return dead
        return not self.is_running
1008
1009
1010

    @property
    def dead_error(self) -> BaseException:
        """A new ``EngineDeadError`` instance, created on every access."""
        error = EngineDeadError()
        return error