async_llm.py 42 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import os
import socket
6
import time
7
import warnings
8
from collections.abc import AsyncGenerator, Iterable, Mapping
9
from copy import copy
10
from typing import Any
11

12
import torch
13

14
import vllm.envs as envs
15
from vllm import TokensPrompt
16
from vllm.config import VllmConfig
17
18
19
20
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
21
from vllm.engine.arg_utils import AsyncEngineArgs
22
from vllm.engine.protocol import EngineClient
23
from vllm.inputs import PromptType, StreamingInput
24
25
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
26
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
27
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
28
from vllm.plugins.io_processors import get_io_processor
29
from vllm.pooling_params import PoolingParams
30
from vllm.renderers import BaseRenderer, merge_kwargs
31
32
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import extract_prompt_components
33
from vllm.sampling_params import RequestOutputKind, SamplingParams
34
from vllm.tasks import SupportedTask
35
from vllm.tokenizers import TokenizerLike
36
from vllm.tracing import init_tracer
37
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
38
from vllm.usage.usage_lib import UsageContext
39
40
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
41
from vllm.v1.engine import EngineCoreRequest, PauseMode
42
from vllm.v1.engine.core_client import EngineCoreClient
43
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
44
from vllm.v1.engine.input_processor import InputProcessor
45
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
46
from vllm.v1.engine.parallel_sampling import ParentRequest
47
from vllm.v1.executor import Executor
48
49
50
51
52
from vllm.v1.metrics.loggers import (
    StatLoggerFactory,
    StatLoggerManager,
    load_stat_logger_plugin_factories,
)
53
from vllm.v1.metrics.prometheus import shutdown_prometheus
54
from vllm.v1.metrics.stats import IterationStats
55
56
57
58

logger = init_logger(__name__)


59
60
61
62
63
64
65
66
67
68
69
70
class InputStreamError(Exception):
    """Carrier for an exception raised by a caller-supplied input stream.

    The original exception is kept on ``cause`` so that ``generate()`` can
    re-raise it as-is instead of wrapping it in ``EngineGenerateError``.
    The message of this wrapper is ``str(cause)``.
    """

    def __init__(self, cause: Exception):
        super().__init__(str(cause))
        # The wrapped user exception, re-raised by the consumer.
        self.cause = cause


class AsyncLLM(EngineClient):
72
73
    """An asynchronous wrapper for the vLLM engine."""

74
75
76
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: list[StatLoggerFactory] | None = None,
        aggregate_engine_logging: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats. Stats collection is also turned
                on (without the default loggers) when custom stat loggers are
                passed in or discovered through plugins.
            usage_context: Usage context of the LLM. Not referenced by this
                constructor.
            mm_registry: Multi-modal registry. Not referenced by this
                constructor.
            use_cached_outputs: Not referenced by this constructor.
            log_requests: Whether to log requests.
            start_engine_loop: Not referenced by this constructor. The output
                handler is started eagerly when an event loop is already
                running, and otherwise lazily when requests are added.
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            aggregate_engine_logging: Forwarded to the StatLoggerManager.
            client_addresses: Forwarded to the EngineCore client.
            client_count: Number of front-end clients sharing the engine
                (presumably API-server processes; TODO confirm). Forwarded to
                the EngineCore client and the StatLoggerManager.
            client_index: Index of this client; forwarded to the EngineCore
                client.

        Returns:
            None
        """
        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config

        tracing_endpoint = self.observability_config.otlp_traces_endpoint
        if tracing_endpoint is not None:
            init_tracer("vllm.llm_engine", tracing_endpoint)

        self.log_requests = log_requests

        # Explicitly passed stat loggers plus any discovered via plugins.
        custom_stat_loggers = list(stat_loggers or [])
        custom_stat_loggers.extend(load_stat_logger_plugin_factories())

        # Custom loggers force stats collection on; the default loggers are
        # only enabled when log_stats itself was requested (see
        # enable_default_loggers below).
        has_custom_loggers = bool(custom_stat_loggers)
        self.log_stats = log_stats or has_custom_loggers
        if not log_stats and has_custom_loggers:
            logger.info(
                "AsyncLLM created with log_stats=False, "
                "but custom stat loggers were found; "
                "enabling logging without default stat loggers."
            )

        self.input_processor = InputProcessor(self.vllm_config)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(
            self.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
        )
        if tracing_endpoint is not None:
            self.output_processor.tracing_enabled = True

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Loggers.
        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=custom_stat_loggers,
                enable_default_loggers=log_stats,
                client_count=client_count,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        # Pause / resume state for async RL workflows.
        self._pause_cond = asyncio.Condition()
        self._paused = False
        self._client_count = client_count

        self.output_handler: asyncio.Task | None = None
        # Start the output handler eagerly if we are already inside a running
        # event loop; otherwise it is started lazily when requests are added.
        # Only get_running_loop() is guarded: it raises RuntimeError when no
        # loop is running. A RuntimeError from starting the handler itself is
        # a real failure and must not be swallowed.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass
        else:
            self._run_output_handler()

        # Optional CPU-only torch profiler for this frontend process.
        if (
            vllm_config.profiler_config.profiler == "torch"
            and not vllm_config.profiler_config.ignore_frontend
        ):
            profiler_dir = vllm_config.profiler_config.torch_profiler_dir
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                profiler_dir,
            )
            worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=vllm_config.profiler_config.torch_profiler_with_stack,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    profiler_dir,
                    worker_name=worker_name,
                    use_gzip=vllm_config.profiler_config.torch_profiler_use_gzip,
                ),
            )
        else:
            self.profiler = None

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_log_requests: bool = False,
        aggregate_engine_logging: bool = False,
        disable_log_stats: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> "AsyncLLM":
        """Build an AsyncLLM from an already-constructed VllmConfig.

        The executor class is resolved from the config. ``enable_log_requests``
        maps to ``log_requests`` and ``disable_log_stats`` is inverted into
        ``log_stats``; every other argument is forwarded unchanged.
        """
        executor_class = Executor.get_class(vllm_config)
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
            log_requests=enable_log_requests,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            aggregate_engine_logging=aggregate_engine_logging,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs.

        The engine config is built from ``engine_args`` for the given usage
        context and the executor class is resolved from that config. Request
        and stats logging follow the corresponding engine-arg flags.
        """
        config = engine_args.create_engine_config(usage_context)
        executor = Executor.get_class(config)

        return cls(
            vllm_config=config,
            executor_class=executor,
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    def __del__(self):
        """Release engine resources when the object is garbage collected."""
        release = self.shutdown
        release()

    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC.

        Steps, in order:
            1. Shut down the Prometheus metrics.
            2. Shut down the EngineCore client.
            3. Close the input processor.
            4. Cancel the output handler task.

        Each attribute is read with ``getattr(..., None)`` so this also works
        on an instance whose construction did not complete (``__del__`` calls
        it unconditionally).
        """
        shutdown_prometheus()

        core_client = getattr(self, "engine_core", None)
        if core_client:
            core_client.shutdown()

        processor = getattr(self, "input_processor", None)
        if processor:
            processor.close()

        task = getattr(self, "output_handler", None)
        if task is not None:
            cancel_task_threadsafe(task)

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks the engine supports.

        The EngineCore is queried on the first call. The answer is then
        memoized on the instance as ``_supported_tasks``.
        """
        try:
            return self._supported_tasks
        except AttributeError:
            pass

        self._supported_tasks = await self.engine_core.get_supported_tasks_async()
        return self._supported_tasks

    async def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest
        | PromptType
        | DictPrompt
        | TokPrompt
        | AsyncGenerator[StreamingInput, None],
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        prompt_text: str | None = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        Dispatch by prompt type:
            - Async generator of ``StreamingInput``: handed off to the
              streaming-input path.
            - ``EngineCoreRequest``: used as-is. Its own ``request_id`` wins
              over the ``request_id`` argument, and ``prompt_text`` may be
              supplied alongside it.
            - Any other prompt: converted by the InputProcessor.
              ``prompt_text`` must not be given in this case.

        Behavior:
            - Non-streaming requests wait while generation is paused.
            - With ``n > 1``, the request is fanned out into child requests
              that share a single output collector.

        Returns:
            The ``RequestOutputCollector`` that receives this request's
            outputs.

        Raises:
            EngineDeadError: the engine has already errored.
            ValueError: prompt logprobs were requested together with
                ``kv_sharing_fast_prefill``, or ``prompt_text`` was passed
                with a prompt that is not an ``EngineCoreRequest``.
        """
        if self.errored:
            raise EngineDeadError()

        is_pooling = isinstance(params, PoolingParams)

        fast_prefill = self.vllm_config.cache_config.kv_sharing_fast_prefill
        if fast_prefill and not is_pooling and params.prompt_logprobs:
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs"
            )

        # Legacy location of the truncation setting: warn, then fold it into
        # tokenization_kwargs.
        legacy_truncation = params.truncate_prompt_tokens
        if legacy_truncation is not None:
            warnings.warn(
                f"The `truncate_prompt_tokens` parameter in `{type(params).__name__}` "
                "is deprecated and will be removed in v0.16. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=legacy_truncation),
            )

        if isinstance(prompt, AsyncGenerator):
            # Streaming input case.
            return await self._add_streaming_input_request(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
                data_parallel_rank,
            )

        # Convert Input --> Request.
        if not isinstance(prompt, EngineCoreRequest):
            if prompt_text is not None:
                raise ValueError(
                    "should only provide prompt_text with EngineCoreRequest"
                )
            request = self.input_processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                supported_tasks=await self.get_supported_tasks(),
            )
            prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
        else:
            request = prompt
            if request_id != request.request_id:
                logger.warning_once(
                    "AsyncLLM.add_request() was passed a request_id parameter that "
                    "does not match the EngineCoreRequest.request_id attribute. The "
                    "latter will be used, and the former will be ignored."
                )

        self.input_processor.assign_request_id(request)

        # The output handler is started lazily here (not only in __init__) so
        # that construction can happen before an event loop exists, which
        # lets the OpenAI server handle startup failures gracefully.
        self._run_output_handler()

        # Block new requests while generation is paused.
        async with self._pause_cond:
            await self._pause_cond.wait_for(lambda: not self._paused)

        # The collector is keyed on the (possibly reassigned) request id.
        queue = RequestOutputCollector(params.output_kind, request.request_id)

        # From here on, use the cloned params that process_inputs() may have
        # updated.
        params = request.params

        if is_pooling or params.n == 1:
            await self._add_request(request, prompt_text, None, 0, queue)
            return queue

        assert isinstance(params, SamplingParams)

        # n > 1: every child is a copy of the request, except the last, which
        # reuses the original object.
        parent_request = ParentRequest(request)
        for idx in range(params.n):
            child_id, child_params = parent_request.get_child_info(idx)
            child_request = request if idx == params.n - 1 else copy(request)
            child_request.request_id = child_id
            child_request.sampling_params = child_params
            await self._add_request(
                child_request, prompt_text, parent_request, idx, queue
            )
        return queue

    async def _add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        index: int,
        queue: RequestOutputCollector,
    ):
        """Register one EngineCoreRequest locally, then submit it to the core.

        Steps:
            1. Register with this process's OutputProcessor, together with
               its prompt text, optional parent (for n > 1 fan-out), child
               index and output collector.
            2. Send the request to the EngineCore (separate process).
            3. Log the addition if request logging is enabled.
        """
        self.output_processor.add_request(request, prompt, parent_req, index, queue)
        await self.engine_core.add_request_async(request)

        if not self.log_requests:
            return
        logger.info("Added request %s.", request.request_id)

    async def _add_streaming_input_request(
        self,
        request_id: str,
        input_stream: AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> RequestOutputCollector:
        """Register a request whose prompt arrives incrementally.

        Setup:
            - The params are validated, and cloned unless ``skip_clone`` is
              already set.
            - A placeholder request with a single-token prompt is built. It
              validates the inputs, and is submitted once the stream ends to
              signal that no more input is coming.

        A background task then consumes ``input_stream``:
            - Each chunk is submitted as a resumable request under the same
              internal request id.
            - Chunk processing errors are put on the collector wrapped in
              ``InputStreamError``.
            - If the task is cancelled, the placeholder is not submitted.

        Returns:
            The collector for the request's outputs.
        """
        self._validate_streaming_input_sampling_params(sampling_params)

        # Arguments shared by the placeholder and every streamed chunk.
        common_kwargs = {
            "arrival_time": arrival_time,
            "lora_request": lora_request,
            "tokenization_kwargs": tokenization_kwargs,
            "trace_headers": trace_headers,
            "priority": priority,
            "data_parallel_rank": data_parallel_rank,
        }

        if not sampling_params.skip_clone:
            sampling_params = sampling_params.clone()
            sampling_params.skip_clone = True

        # Validation request; reused as the end-of-input signal.
        final_req = self.input_processor.process_inputs(
            request_id=request_id,
            prompt=TokensPrompt(prompt_token_ids=[0]),
            params=sampling_params,
            **common_kwargs,  # type: ignore[arg-type]
        )
        self.input_processor.assign_request_id(final_req)
        internal_id = final_req.request_id

        queue = RequestOutputCollector(sampling_params.output_kind, internal_id)

        async def pump_input_stream():
            aborted = False
            try:
                async for chunk in input_stream:
                    chunk_params = chunk.sampling_params
                    if not chunk_params:
                        chunk_params = sampling_params
                    else:
                        self._validate_streaming_input_sampling_params(chunk_params)
                    # TODO(nick): Avoid re-validating reused sampling parameters
                    chunk_req = self.input_processor.process_inputs(
                        request_id=internal_id,
                        prompt=chunk.prompt,
                        params=chunk_params,
                        resumable=True,
                        **common_kwargs,  # type: ignore[arg-type]
                    )
                    chunk_req.external_req_id = request_id
                    if chunk_req.prompt_embeds is not None:
                        raise ValueError(
                            "prompt_embeds not supported for streaming inputs"
                        )
                    chunk_text, _, _ = extract_prompt_components(
                        self.model_config, chunk.prompt
                    )
                    await self._add_request(chunk_req, chunk_text, None, 0, queue)
            except (asyncio.CancelledError, GeneratorExit):
                aborted = True
            except Exception as error:
                # Wrapped so generate() re-raises the user's exception as-is
                # instead of wrapping it in EngineGenerateError.
                queue.put(InputStreamError(error))
            finally:
                queue._input_stream_task = None
                # Signal end of input unless the session was aborted.
                if not aborted:
                    await self._add_request(final_req, None, None, 0, queue)

        # Ensure output handler is running.
        self._run_output_handler()

        pump_task = asyncio.create_task(pump_input_stream())
        queue._input_stream_task = pump_task
        return queue

    @staticmethod
    def _validate_streaming_input_sampling_params(
        params: SamplingParams | PoolingParams,
    ):
        """Reject params that the streaming-input path cannot handle.

        Accepted params must be ``SamplingParams`` with all of:
            - ``n <= 1``
            - an output kind other than ``FINAL_ONLY``
            - no ``stop`` entries

        Raises:
            ValueError: any of the conditions above is violated.
        """
        supported = (
            isinstance(params, SamplingParams)
            and params.n <= 1
            and params.output_kind != RequestOutputKind.FINAL_ONLY
            and not params.stop
        )
        if not supported:
            raise ValueError(
                "Input streaming not currently supported "
                "for pooling models, n > 1, request_kind = FINAL_ONLY "
                "or with stop strings."
            )

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: EngineCoreRequest
        | PromptType
        | DictPrompt
        | TokPrompt
        | AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main entry point used by the API server to run a request.

        The request is registered through ``add_request()``, which:
            1) Processes the input.
            2) Creates the per-request RequestOutputCollector.
            3) Registers the request with the OutputProcessor.
            4) Sends the request to the EngineCore (separate process).

        A separate output_handler task pulls outputs from the EngineCore and
        pushes them into the per-request collector. This generator drains
        that collector and yields each RequestOutput to the caller until one
        reports ``finished``. The ``STREAM_FINISHED`` sentinel is never
        yielded.

        Error handling:
            - Cancellation or generator close: the request is aborted and the
              exception is re-raised.
            - ``EngineDeadError`` and ``ValueError``: re-raised unchanged.
            - Errors from a streaming input generator: the request is
              aborted, then the original exception is re-raised.
            - Any other exception: the request is aborted and the exception
              is wrapped in ``EngineGenerateError``.
        """

        q: RequestOutputCollector | None = None
        try:
            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                prompt_text=prompt_text,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                assert isinstance(out, RequestOutput)
                finished = out.finished
                if out is not STREAM_FINISHED:
                    yield out

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError as e:
            if self.log_requests:
                logger.info("Request %s failed (bad request): %s.", request_id, e)
            raise

        # Error from input stream generator - propagate directly.
        except InputStreamError as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed (input error): %s.", request_id, e)
            raise e.cause from e

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                # str(e) can itself raise; fall back to naming the class of
                # the secondary error.
                try:
                    s = f"{e.__class__.__name__}: {e}"
                except Exception as e2:
                    s = (
                        f"{e.__class__.__name__}: "
                        "error during printing an exception of class "
                        + e2.__class__.__name__
                    )
                logger.info("Request %s failed due to %s.", request_id, s)
            raise EngineGenerateError() from e

        finally:
            if q is not None:
                q.close()

    def _run_output_handler(self):
        """Background loop: pulls from EngineCore and pushes to AsyncStreams.

        Idempotent: does nothing if the task has already been created.

        Each iteration of the task:
            1. Awaits the next batch of EngineCoreOutputs.
            2. Processes it in slices of at most
               ``VLLM_V1_OUTPUT_PROC_CHUNK_SIZE``.
            3. Aborts any requests the OutputProcessor flags.
            4. Records stats.

        Any exception ends the loop and is propagated to the OutputProcessor.
        """
        if self.output_handler is not None:
            return

        # Capture dependencies as locals so the task holds no reference back
        # to the AsyncLLM; a cycle would keep it from being garbage collected
        # and cleaned up.
        core_client = self.engine_core
        out_proc = self.output_processor
        stats_enabled = self.log_stats
        stat_loggers = self.logger_manager
        in_proc = self.input_processor
        max_chunk = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE

        async def drain_engine_outputs():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    batch = await core_client.get_output_async()
                    per_request = batch.outputs
                    total = len(per_request)

                    iter_stats = IterationStats() if (stats_enabled and total) else None

                    # Process in bounded slices so the event loop is not
                    # blocked for too long by one large batch.
                    for lo in range(0, total, max_chunk):
                        hi = lo + max_chunk

                        # 2) Process EngineCoreOutputs.
                        processed = out_proc.process_outputs(
                            per_request[lo:hi], batch.timestamp, iter_stats
                        )
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed.request_outputs

                        # Yield to other tasks between slices (not after the
                        # final one).
                        if hi < total:
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        if processed.reqs_to_abort:
                            await core_client.abort_requests_async(
                                processed.reqs_to_abort
                            )

                    out_proc.update_scheduler_stats(batch.scheduler_stats)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if stat_loggers:
                        stat_loggers.record(
                            engine_idx=batch.engine_index,
                            scheduler_stats=batch.scheduler_stats,
                            iteration_stats=iter_stats,
                            mm_cache_stats=in_proc.stat_mm_cache(),
                        )
            except Exception as exc:
                logger.exception("AsyncLLM output_handler failed.")
                out_proc.propagate_error(exc)

        self.output_handler = asyncio.create_task(drain_engine_outputs())

    async def abort(
        self, request_id: str | Iterable[str], internal: bool = False
    ) -> None:
        """Abort RequestId in OutputProcessor and EngineCore.

        ``request_id`` may be a single id or an iterable of ids. The ids
        forwarded to the EngineCore are whatever
        ``OutputProcessor.abort_requests`` returns. That may differ from the
        input (e.g. child requests of an n > 1 parent); TODO confirm.
        """
        if isinstance(request_id, str):
            request_ids = (request_id,)
        else:
            request_ids = as_list(request_id)

        engine_ids = self.output_processor.abort_requests(request_ids, internal)
        await self.engine_core.abort_requests_async(engine_ids)

        if self.log_requests:
            logger.info("Aborted request(s) %s.", ",".join(request_ids))

    async def pause_generation(
        self,
        *,
        mode: PauseMode = "abort",
        wait_for_inflight_requests: bool | None = None,
        clear_cache: bool = True,
    ) -> None:
        """
        Pause generation to allow model weight updates.

        New generation/encoding requests are blocked until resume.

        Args:
            mode: How to handle in-flight requests:
                - ``"abort"``: Abort all in-flight requests immediately
                  (default).
                - ``"wait"``: Wait for in-flight requests to complete.
                - ``"keep"``: Freeze requests in queue; they resume on
                  :meth:`resume_generation`.
            wait_for_inflight_requests: DEPRECATED: use mode argument.
                Whether to wait for in-flight requests to complete before pausing.
                A truthy value emits a ``DeprecationWarning`` and forces
                ``mode="wait"``.
            clear_cache: Whether to reset the prefix, multi-modal and encoder
                caches after draining. Set to ``False`` to preserve cache for
                faster resume. Default is ``True`` (clear caches).

        Raises:
            ValueError: If ``mode`` is not ``"abort"``, ``"wait"`` or
                ``"keep"``. This is checked before any state is modified.
            NotImplementedError: If ``mode`` is not ``"keep"`` and more than
                one API-server client is attached.
        """
        if wait_for_inflight_requests:
            warnings.warn(
                "The `wait_for_inflight_requests` parameter in "
                "`AsyncLLM.pause_generation()` is deprecated. "
                "Please use `mode` argument instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            mode = "wait"

        # Validate before touching any state: rejecting an unknown mode only
        # after setting self._paused = True would leave the engine stuck in
        # the paused state.
        if mode not in ("abort", "wait", "keep"):
            raise ValueError(f"Invalid mode: {mode}")

        if mode == "keep":
            # Freeze requests in the scheduler - they will resume on
            # resume_generation().
            await self.engine_core.pause_scheduler_async()
        else:
            if self._client_count > 1:
                raise NotImplementedError(
                    "pause_generation is not supported with --api-server-count > 1"
                    " when mode is not 'keep'"
                )
            async with self._pause_cond:
                if not self._paused:
                    # Set the flag first; waiters on _pause_cond are woken by
                    # resume_generation() once it is cleared again.
                    self._paused = True

                    if mode == "abort":
                        request_ids = list(self.output_processor.request_states.keys())
                        if request_ids:
                            await self.abort(request_ids, internal=True)
                    else:  # mode == "wait" (validated above)
                        if self.output_processor.has_unfinished_requests():
                            await self.output_processor.wait_for_requests_to_drain()

        # Clear cache
        if clear_cache:
            await self.reset_prefix_cache()
            await self.reset_mm_cache()
            await self.reset_encoder_cache()

    async def resume_generation(self) -> None:
        """Undo :meth:`pause_generation`.

        While holding ``self._pause_cond``: resume the engine-core scheduler,
        clear the local paused flag, and wake every coroutine waiting on the
        condition.
        """
        cond = self._pause_cond
        async with cond:
            await self.engine_core.resume_scheduler_async()
            self._paused = False
            # Wake up all waiting requests.
            cond.notify_all()

    async def is_paused(self) -> bool:
        """Report whether the engine is currently paused.

        ``self._paused`` is read while holding ``self._pause_cond``, the same
        lock pause/resume use when they modify it.
        """
        async with self._pause_cond:
            paused = self._paused
        return paused

    async def encode(
        self,
        prompt: PromptType | DictPrompt | TokPrompt,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Main function called by the API server to kick off a pooling request
            * 1) Creating the per-request output collector.
            * 2) Processing the Input.
            * 3) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request output collector.

        The caller of encode() iterates the returned AsyncGenerator,
        receiving each PoolingRequestOutput until one is marked finished.

        If the caller stops early (the consuming task is cancelled, or the
        generator is closed / garbage collected), the request is aborted in
        both the OutputProcessor and the EngineCore.

        Raises:
            asyncio.CancelledError: Re-raised after aborting the request.
            EngineDeadError: The engine is dead; no abort is attempted.
            ValueError: Request validation failed.
            EngineGenerateError: Wraps any other exception, raised after
                aborting the request.
        """

        q: RequestOutputCollector | None = None

        try:
            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, encode() is
        # cancelled; if the consumer closes the generator early (aclose() or
        # garbage collection), GeneratorExit is raised at the `yield`. In
        # both cases the request must be aborted so the engine stops work.
        except (asyncio.CancelledError, GeneratorExit):
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

        finally:
            if q is not None:
                q.close()

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """Tokenizer held by the input processor; may be ``None``."""
        processor = self.input_processor
        return processor.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer from ``input_processor.get_tokenizer()``.

        The return annotation is non-optional; presumably the processor raises
        when no tokenizer is configured — TODO confirm.
        """
        processor = self.input_processor
        return processor.get_tokenizer()

    @property
    def renderer(self) -> BaseRenderer:
        """Renderer owned by the input processor."""
        processor = self.input_processor
        return processor.renderer

    async def is_tracing_enabled(self) -> bool:
        """Return True when an OTLP traces endpoint is configured."""
        endpoint = self.observability_config.otlp_traces_endpoint  # type: ignore
        return endpoint is not None

    async def do_log_stats(self) -> None:
        """Ask the stat logger manager, when one is set, to emit its logs."""
        manager = self.logger_manager
        if not manager:
            return
        manager.log()

    async def check_health(self) -> None:
        """Raise ``self.dead_error`` when the engine is in an errored state."""
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error

    async def start_profile(self) -> None:
        """Start profiling in the engine core and, if configured, locally.

        ``self.profiler.start`` runs in a worker thread via
        ``asyncio.to_thread``; both operations are awaited together.
        """
        pending = [self.engine_core.profile_async(True)]
        local_profiler = self.profiler
        if local_profiler is not None:
            pending.append(asyncio.to_thread(local_profiler.start))
        await asyncio.gather(*pending)

    async def stop_profile(self) -> None:
        """Stop profiling in the engine core and, if configured, locally.

        ``self.profiler.stop`` runs in a worker thread via
        ``asyncio.to_thread``; both operations are awaited together.
        """
        pending = [self.engine_core.profile_async(False)]
        local_profiler = self.profiler
        if local_profiler is not None:
            pending.append(asyncio.to_thread(local_profiler.stop))
        await asyncio.gather(*pending)

    async def reset_mm_cache(self) -> None:
        """Clear the multi-modal cache: input processor first, then engine core."""
        processor, core = self.input_processor, self.engine_core
        processor.clear_mm_cache()
        await core.reset_mm_cache_async()

    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the engine core's prefix cache.

        Args:
            reset_running_requests: Forwarded positionally to the engine core.
            reset_connector: Forwarded positionally to the engine core.

        Returns:
            The boolean reported by ``engine_core.reset_prefix_cache_async``.
        """
        core = self.engine_core
        outcome = await core.reset_prefix_cache_async(
            reset_running_requests, reset_connector
        )
        return outcome

    async def reset_encoder_cache(self) -> None:
        """Reset the encoder cache held by the engine core."""
        core = self.engine_core
        await core.reset_encoder_cache_async()

    async def sleep(self, level: int = 1) -> None:
        """Put the engine to sleep at ``level``.

        Resets the prefix cache (default arguments) first, then asks the
        engine core to sleep, and records sleep state ``(1, level)`` when a
        stat logger manager is present.
        """
        await self.reset_prefix_cache()
        await self.engine_core.sleep_async(level)

        manager = self.logger_manager
        if manager is None:
            return
        manager.record_sleep_state(1, level)

    async def wake_up(self, tags: list[str] | None = None) -> None:
949
        await self.engine_core.wake_up_async(tags)
950

951
952
953
        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

954
955
956
    async def is_sleeping(self) -> bool:
        """Return what the engine core reports about its sleep state."""
        sleeping = await self.engine_core.is_sleeping_async()
        return sleeping

    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests.

        Returns:
            The boolean reported by ``engine_core.add_lora_async``.
        """
        core = self.engine_core
        return await core.add_lora_async(lora_request)

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter.

        Returns:
            The boolean reported by ``engine_core.remove_lora_async``.
        """
        core = self.engine_core
        return await core.remove_lora_async(lora_id)

    async def list_loras(self) -> set[int]:
        """List all registered adapters, as reported by the engine core."""
        core = self.engine_core
        return await core.list_loras_async()

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted.

        Returns:
            The boolean reported by ``engine_core.pin_lora_async``.
        """
        core = self.engine_core
        return await core.pin_lora_async(lora_id)

    async def collective_rpc(
        self,
        method: str,
976
        timeout: float | None = None,
977
        args: tuple = (),
978
        kwargs: dict | None = None,
979
    ):
980
981
982
983
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
984
985
            method, timeout, args, kwargs
        )
986

987
988
989
990
991
992
993
994
    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """Wait until no data-parallel engine reports itself as running.

        Polls ``engine_core.dp_engines_running()`` about once per second.

        Args:
            drain_timeout: Maximum time to wait, in seconds.

        Raises:
            TimeoutError: If engines are still running after
                ``drain_timeout`` seconds.
        """
        # Measure elapsed time with a monotonic clock: wall-clock time can
        # jump (NTP sync, manual changes) and would distort the timeout.
        deadline = time.monotonic() + drain_timeout
        while time.monotonic() < deadline:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

            logger.info("Engines are still running, waiting for requests to drain...")
            await asyncio.sleep(1)  # Wait 1 second before checking again

        raise TimeoutError(
            f"Timeout reached after {drain_timeout} seconds "
            "waiting for requests to drain."
        )

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ):
        """
        Scale up or down the data parallel size by adding or removing
        engine cores.

        Waits for requests to drain, asks the engine core to rescale, then
        updates ``parallel_config.data_parallel_size``. When scaling up with
        stats logging enabled, the stat logger manager is rebuilt for the
        new engine indexes.

        Args:
            new_data_parallel_size: The new number of data parallel workers
            drain_timeout:
                Maximum time to wait for requests to drain (seconds)

        Raises:
            TimeoutError: Propagated from :meth:`wait_for_requests_to_drain`
                when requests do not drain within ``drain_timeout``.
        """
        old_data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if old_data_parallel_size == new_data_parallel_size:
            logger.info(
                "Data parallel size is already %s, skipping scale",
                new_data_parallel_size,
            )
            return

        # Report the real direction; this path handles both growth and shrink.
        direction = "up" if new_data_parallel_size > old_data_parallel_size else "down"
        logger.info(
            "Waiting for requests to drain before scaling %s to %s engines...",
            direction,
            new_data_parallel_size,
        )
        await self.wait_for_requests_to_drain(drain_timeout)
        logger.info(
            "Requests have been drained, proceeding with scale to %s engines",
            new_data_parallel_size,
        )
        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size

        # recreate stat loggers
        if new_data_parallel_size > old_data_parallel_size and self.log_stats:
            # TODO(rob): fix this after talking with Ray team.
            # This resets all the prometheus metrics since we
            # unregister during initialization. Need to understand
            # the intended behavior here better.
            self.logger_manager = StatLoggerManager(
                vllm_config=self.vllm_config,
                engine_idxs=list(range(new_data_parallel_size)),
                custom_stat_loggers=None,
            )

    @property
    def is_running(self) -> bool:
        """True until the output handler task has completed.

        ``self.output_handler`` is None before the loop is started, which
        counts as running.
        """
        handler = self.output_handler
        if handler is None:
            return True
        return not handler.done()

    @property
    def is_stopped(self) -> bool:
        """Same value as :attr:`errored`."""
        stopped = self.errored
        return stopped

    @property
    def errored(self) -> bool:
        """True when the engine core is dead or the output handler stopped."""
        dead = self.engine_core.resources.engine_dead
        if dead:
            return dead
        return not self.is_running

    @property
    def dead_error(self) -> BaseException:
        """A new :class:`EngineDeadError` instance, created on every access."""
        error = EngineDeadError()
        return error

    async def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest
    ) -> None:
        """
        Initialize weight transfer for RL training.

        Forwards ``request.init_info`` as the ``init_info`` keyword of a
        ``collective_rpc("init_weight_transfer_engine", ...)`` call.

        Args:
            request: Weight transfer initialization request with backend-specific info

        Raises:
            TypeError: If ``request`` is not a ``WeightTransferInitRequest``.
        """
        if not isinstance(request, WeightTransferInitRequest):
            raise TypeError(f"Expected WeightTransferInitRequest, got {type(request)}")

        await self.collective_rpc(
            "init_weight_transfer_engine", kwargs={"init_info": request.init_info}
        )

    async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
        """
        Batched weight update for RL training.

        Forwards ``request.update_info`` as the ``update_info`` keyword of a
        ``collective_rpc("update_weights", ...)`` call.

        Args:
            request: Weight update request with backend-specific update info

        Raises:
            TypeError: If ``request`` is not a ``WeightTransferUpdateRequest``.
        """
        if not isinstance(request, WeightTransferUpdateRequest):
            raise TypeError(
                f"Expected WeightTransferUpdateRequest, got {type(request)}"
            )

        payload = {"update_info": request.update_info}
        await self.collective_rpc("update_weights", kwargs=payload)