async_llm.py 41.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import os
import socket
6
import time
7
import warnings
8
from collections.abc import AsyncGenerator, Iterable, Mapping
9
from copy import copy
10
from typing import Any
11

12
import torch
13

14
import vllm.envs as envs
15
from vllm import TokensPrompt
16
from vllm.config import VllmConfig
17
18
19
20
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
21
from vllm.engine.arg_utils import AsyncEngineArgs
22
from vllm.engine.protocol import EngineClient
23
from vllm.inputs import PromptType, StreamingInput
24
25
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
26
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
27
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
28
from vllm.plugins.io_processors import get_io_processor
29
from vllm.pooling_params import PoolingParams
30
from vllm.renderers import BaseRenderer, merge_kwargs
31
32
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import extract_prompt_components
33
from vllm.sampling_params import RequestOutputKind, SamplingParams
34
from vllm.tasks import SupportedTask
35
from vllm.tokenizers import TokenizerLike
36
from vllm.tracing import init_tracer
37
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
38
from vllm.usage.usage_lib import UsageContext
39
40
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
41
from vllm.v1.engine import EngineCoreRequest, PauseMode
42
from vllm.v1.engine.core_client import EngineCoreClient
43
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
44
from vllm.v1.engine.input_processor import InputProcessor
45
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
46
from vllm.v1.engine.parallel_sampling import ParentRequest
47
from vllm.v1.executor import Executor
48
49
50
51
52
from vllm.v1.metrics.loggers import (
    StatLoggerFactory,
    StatLoggerManager,
    load_stat_logger_plugin_factories,
)
53
from vllm.v1.metrics.prometheus import shutdown_prometheus
54
from vllm.v1.metrics.stats import IterationStats
55
56
57
58

# Module-level logger used by AsyncLLM for request and lifecycle messages.
logger = init_logger(__name__)


59
60
61
62
63
64
65
66
67
68
69
70
class InputStreamError(Exception):
    """Carrier for an exception raised by a user-supplied input stream.

    ``AsyncLLM.generate()`` unwraps this and re-raises the original
    ``cause`` rather than converting it into an ``EngineGenerateError``.
    The exception message is the string form of ``cause``.
    """

    def __init__(self, cause: Exception):
        super().__init__(str(cause))
        # The original exception, re-raised by generate().
        self.cause = cause


71
72
73
74
class AsyncLLM(EngineClient):
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: list[StatLoggerFactory] | None = None,
        aggregate_engine_logging: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats. If False but custom or
                plugin-provided stat loggers are present, stats logging is
                still enabled, only without the default stat loggers.
            usage_context: Usage context of the LLM. Not referenced by this
                constructor.
            mm_registry: Multi-modal registry. Not referenced by this
                constructor.
            use_cached_outputs: Whether to use cached outputs. Not referenced
                by this constructor.
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop. Not
                referenced by this constructor. The output handler is started
                here only if an asyncio event loop is already running;
                otherwise it is started lazily when requests are added.
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            aggregate_engine_logging: Forwarded to ``StatLoggerManager``.
            client_addresses: Forwarded to the EngineCore client.
            client_count: Number of frontend clients. Forwarded to the
                EngineCore client and ``StatLoggerManager``. A value > 1 makes
                ``pause_generation`` reject modes other than ``"keep"``.
            client_index: Index of this client; forwarded to the EngineCore
                client.

        Returns:
            None
        """
        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config

        tracing_endpoint = self.observability_config.otlp_traces_endpoint
        if tracing_endpoint is not None:
            init_tracer("vllm.llm_engine", tracing_endpoint)

        self.log_requests = log_requests

        # User-supplied loggers plus any registered via plugins.
        custom_stat_loggers = list(stat_loggers or [])
        custom_stat_loggers.extend(load_stat_logger_plugin_factories())

        # Custom loggers force stats collection on, even with log_stats=False.
        has_custom_loggers = bool(custom_stat_loggers)
        self.log_stats = log_stats or has_custom_loggers
        if not log_stats and has_custom_loggers:
            logger.info(
                "AsyncLLM created with log_stats=False, "
                "but custom stat loggers were found; "
                "enabling logging without default stat loggers."
            )

        self.input_processor = InputProcessor(self.vllm_config)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(
            self.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
        )
        if tracing_endpoint is not None:
            self.output_processor.tracing_enabled = True

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Loggers.
        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=custom_stat_loggers,
                # Default loggers only when the caller actually asked for stats.
                enable_default_loggers=log_stats,
                client_count=client_count,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        # Pause / resume state for async RL workflows.
        self._pause_cond = asyncio.Condition()
        self._paused = False
        self._client_count = client_count

        self.output_handler: asyncio.Task | None = None
        try:
            # Only probe for a running loop inside the try. Starting the handler
            # is kept out of it so a RuntimeError raised while starting it is
            # not mistaken for "no running event loop" and silently dropped.
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop yet; the handler is started lazily by
            # add_request() / _add_streaming_input_request().
            pass
        else:
            # Start output handler eagerly if we are in the asyncio eventloop.
            self._run_output_handler()

        if (
            vllm_config.profiler_config.profiler == "torch"
            and not vllm_config.profiler_config.ignore_frontend
        ):
            profiler_dir = vllm_config.profiler_config.torch_profiler_dir
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                profiler_dir,
            )
            worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=vllm_config.profiler_config.torch_profiler_with_stack,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    profiler_dir,
                    worker_name=worker_name,
                    use_gzip=vllm_config.profiler_config.torch_profiler_use_gzip,
                ),
            )
        else:
            self.profiler = None

207
208
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_log_requests: bool = False,
        aggregate_engine_logging: bool = False,
        disable_log_stats: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> "AsyncLLM":
        """Build an AsyncLLM from an already-constructed ``VllmConfig``.

        The executor class is resolved from the config via
        ``Executor.get_class``. ``enable_log_requests`` becomes the
        constructor's ``log_requests``, and ``disable_log_stats`` is inverted
        into ``log_stats``. All other arguments are forwarded unchanged.
        """
        resolved_executor = Executor.get_class(vllm_config)
        return cls(
            vllm_config=vllm_config,
            executor_class=resolved_executor,
            log_stats=not disable_log_stats,
            log_requests=enable_log_requests,
            usage_context=usage_context,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            aggregate_engine_logging=aggregate_engine_logging,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

236
237
238
239
240
241
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs.

        The ``VllmConfig`` is built with
        ``engine_args.create_engine_config(usage_context)``, and the executor
        class is resolved from it. Request/stat logging flags are taken from
        ``engine_args`` (``disable_log_stats`` is inverted into ``log_stats``).
        """
        config = engine_args.create_engine_config(usage_context)
        resolved_executor = Executor.get_class(config)

        return cls(
            vllm_config=config,
            executor_class=resolved_executor,
            usage_context=usage_context,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
        )

261
262
263
    def __del__(self):
        """Release engine resources when the object is garbage collected.

        Delegates to ``shutdown()``. That method looks attributes up with
        ``getattr`` defaults, so it tolerates an instance whose ``__init__``
        did not complete.
        """
        self.shutdown()

264
265
266
    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC.

        Steps, in order:
        1. Shut down the Prometheus metrics.
        2. Shut down the EngineCore client.
        3. Close the input processor.
        4. Cancel the output-handler task.

        Each attribute is fetched with a ``None`` default, so this is safe to
        call on a partially-initialized instance (e.g. from ``__del__``).
        """
        shutdown_prometheus()

        core_client = getattr(self, "engine_core", None)
        if core_client:
            core_client.shutdown()

        processor = getattr(self, "input_processor", None)
        if processor:
            processor.close()

        handler_task = getattr(self, "output_handler", None)
        if handler_task is None:
            return
        cancel_task_threadsafe(handler_task)
278

279
    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the engine, querying it only once.

        The first call asks EngineCore and memoizes the answer on
        ``self._supported_tasks``. Later calls return the cached tuple.
        """
        try:
            return self._supported_tasks
        except AttributeError:
            pass

        self._supported_tasks = await self.engine_core.get_supported_tasks_async()
        return self._supported_tasks
285

286
287
288
    async def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest
        | PromptType
        | DictPrompt
        | TokPrompt
        | AsyncGenerator[StreamingInput, None],
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        prompt_text: str | None = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM and return its output collector.

        Behavior:
        - An ``AsyncGenerator`` prompt is delegated to the streaming-input
          path.
        - A prebuilt ``EngineCoreRequest`` is used as-is; its own
          ``request_id`` wins over the ``request_id`` argument.
        - Any other prompt goes through ``InputProcessor.process_inputs``.
        - Waits while generation is paused before registering the request.
        - For sampling requests with ``n > 1``, fans out ``n`` child requests
          that all share the single returned collector.

        Raises:
            EngineDeadError: if the engine has errored.
            ValueError: if prompt logprobs are requested while
                ``kv_sharing_fast_prefill`` is enabled, or if ``prompt_text``
                is given for a prompt that is not an ``EngineCoreRequest``.
        """
        if self.errored:
            raise EngineDeadError()

        is_pooling = isinstance(params, PoolingParams)

        fast_prefill = self.vllm_config.cache_config.kv_sharing_fast_prefill
        if fast_prefill and not is_pooling and params.prompt_logprobs:
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs"
            )

        # Deprecated location for truncation: fold it into tokenization_kwargs.
        truncation = params.truncate_prompt_tokens
        if truncation is not None:
            warnings.warn(
                f"The `truncate_prompt_tokens` parameter in `{type(params).__name__}` "
                "is deprecated and will be removed in v0.16. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                {"truncate_prompt_tokens": truncation},
            )

        if isinstance(prompt, AsyncGenerator):
            # Streaming input case.
            return await self._add_streaming_input_request(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
                data_parallel_rank,
            )

        # Convert Input --> Request.
        if not isinstance(prompt, EngineCoreRequest):
            if prompt_text is not None:
                raise ValueError(
                    "should only provide prompt_text with EngineCoreRequest"
                )
            processor = self.input_processor
            tasks = await self.get_supported_tasks()
            request = processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                supported_tasks=tasks,
            )
            prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
        else:
            request = prompt
            if request.request_id != request_id:
                logger.warning_once(
                    "AsyncLLM.add_request() was passed a request_id parameter that "
                    "does not match the EngineCoreRequest.request_id attribute. The "
                    "latter will be used, and the former will be ignored."
                )

        self.input_processor.assign_request_id(request)

        # The output handler is started here (not only in __init__) so that
        # the constructor can run before an event loop exists.
        self._run_output_handler()

        # Block new work while generation is paused.
        async with self._pause_cond:
            await self._pause_cond.wait_for(lambda: not self._paused)

        collector = RequestOutputCollector(params.output_kind, request.request_id)

        # Switch to the (possibly cloned/updated) params held by the request.
        params = request.params

        if is_pooling or params.n == 1:
            await self._add_request(request, prompt_text, None, 0, collector)
            return collector

        # Fan out child requests (for n>1).
        assert isinstance(params, SamplingParams)
        parent = ParentRequest(request)
        for child_idx in range(params.n):
            child_id, child_sampling = parent.get_child_info(child_idx)
            # The final child reuses the original request object; earlier
            # children get shallow copies.
            if child_idx == params.n - 1:
                child = request
            else:
                child = copy(request)
            child.request_id = child_id
            child.sampling_params = child_sampling
            await self._add_request(child, prompt_text, parent, child_idx, collector)
        return collector
413

414
415
416
    async def _add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        index: int,
        queue: RequestOutputCollector,
    ):
        """Register ``request`` locally, then submit it to EngineCore.

        The OutputProcessor (this process) is updated first, and the request
        is then sent to EngineCore (separate process). The request is logged
        when request logging is enabled.
        """
        output_processor = self.output_processor
        output_processor.add_request(request, prompt, parent_req, index, queue)

        await self.engine_core.add_request_async(request)

        if not self.log_requests:
            return
        logger.info("Added request %s.", request.request_id)
430

431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
    async def _add_streaming_input_request(
        self,
        request_id: str,
        input_stream: AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> RequestOutputCollector:
        """Register a request whose prompt arrives incrementally.

        A placeholder request (prompt ``[0]``) is processed up front. It
        validates the arguments, fixes the internal request id, and is later
        sent as the "inputs finished" signal once ``input_stream`` is
        exhausted or fails.

        A background task consumes ``input_stream``. Each chunk is processed
        as a ``resumable`` request under the same internal id and submitted
        via ``_add_request``. Errors from the stream are placed on the
        returned collector wrapped in ``InputStreamError``.

        Raises:
            ValueError: if ``sampling_params`` is not supported for input
                streaming (see ``_validate_streaming_input_sampling_params``).
        """
        self._validate_streaming_input_sampling_params(sampling_params)

        inputs = dict(
            arrival_time=arrival_time,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            trace_headers=trace_headers,
            priority=priority,
            data_parallel_rank=data_parallel_rank,
        )

        if not sampling_params.skip_clone:
            sampling_params = sampling_params.clone()
            sampling_params.skip_clone = True

        # Create request for validation, also used as the finished signal
        # once the input stream is closed.
        final_req = self.input_processor.process_inputs(
            request_id=request_id,
            prompt=TokensPrompt(prompt_token_ids=[0]),
            params=sampling_params,
            **inputs,  # type: ignore[arg-type]
        )
        self.input_processor.assign_request_id(final_req)
        internal_req_id = final_req.request_id

        queue = RequestOutputCollector(sampling_params.output_kind, internal_req_id)

        async def handle_inputs():
            cancelled = False
            try:
                async for input_chunk in input_stream:
                    sp = input_chunk.sampling_params
                    if sp:
                        self._validate_streaming_input_sampling_params(sp)
                    else:
                        sp = sampling_params
                    # TODO(nick): Avoid re-validating reused sampling parameters
                    req = self.input_processor.process_inputs(
                        request_id=internal_req_id,
                        prompt=input_chunk.prompt,
                        params=sp,
                        resumable=True,
                        **inputs,  # type: ignore[arg-type]
                    )
                    req.external_req_id = request_id
                    if req.prompt_embeds is not None:
                        raise ValueError(
                            "prompt_embeds not supported for streaming inputs"
                        )
                    prompt_text, _, _ = extract_prompt_components(
                        self.model_config, input_chunk.prompt
                    )
                    await self._add_request(req, prompt_text, None, 0, queue)
            except GeneratorExit:
                cancelled = True
            except asyncio.CancelledError:
                cancelled = True
                # Propagate cancellation so the task finishes in the cancelled
                # state instead of looking like it completed normally. The
                # ``finally`` block below still runs first.
                raise
            except Exception as error:
                # Wrap in InputStreamError so generate() can propagate it
                # without wrapping in EngineGenerateError.
                queue.put(InputStreamError(error))
            finally:
                queue._input_stream_task = None
                if not cancelled:
                    # Send empty final request to indicate that inputs have
                    # finished. Don't send if cancelled (session was aborted).
                    await self._add_request(final_req, None, None, 0, queue)

        # Ensure output handler is running.
        self._run_output_handler()

        queue._input_stream_task = asyncio.create_task(handle_inputs())
        return queue

    @staticmethod
    def _validate_streaming_input_sampling_params(
        params: SamplingParams | PoolingParams,
    ):
        """Reject parameters that input streaming cannot serve.

        Streaming input requires ``SamplingParams`` with ``n <= 1``, an
        ``output_kind`` other than ``FINAL_ONLY``, and no stop strings.

        Raises:
            ValueError: if any of those conditions is violated.
        """
        supported = (
            isinstance(params, SamplingParams)
            and params.n <= 1
            and params.output_kind != RequestOutputKind.FINAL_ONLY
            and not params.stop
        )
        if supported:
            return
        raise ValueError(
            "Input streaming not currently supported "
            "for pooling models, n > 1, request_kind = FINAL_ONLY "
            "or with stop strings."
        )

532
533
534
535
536
    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
537
    async def generate(
        self,
        prompt: EngineCoreRequest
        | PromptType
        | DictPrompt
        | TokPrompt
        | AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request.

        The request is registered via ``add_request()``, which processes the
        input and hands it to the OutputProcessor and EngineCore. This
        generator then yields ``RequestOutput`` objects from the request's
        collector until one is marked ``finished``. The ``STREAM_FINISHED``
        sentinel is consumed but never yielded.

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request collector.

        Error handling:
        - Cancellation / GeneratorExit: the request is aborted, then re-raised.
        - EngineDeadError: re-raised without aborting (engine is shut down).
        - ValueError (bad request): re-raised.
        - InputStreamError: the request is aborted and the original cause
          from the input stream is re-raised.
        - Any other Exception: the request is aborted and the error is
          re-raised as ``EngineGenerateError``.
        The collector is always closed on exit.
        """
        q: RequestOutputCollector | None = None
        try:
            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                prompt_text=prompt_text,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                assert isinstance(out, RequestOutput)
                finished = out.finished
                if out is not STREAM_FINISHED:
                    yield out

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError as e:
            if self.log_requests:
                logger.info("Request %s failed (bad request): %s.", request_id, e)
            raise

        # Error from input stream generator - propagate directly.
        except InputStreamError as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed (input error): %s.", request_id, e)
            raise e.cause from e

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                # str(e) itself may raise; fall back to describing the failure.
                try:
                    s = f"{e.__class__.__name__}: {e}"
                except Exception as e2:
                    s = (
                        f"{e.__class__.__name__}: "
                        "error during printing an exception of class "
                        + e2.__class__.__name__
                    )
                logger.info("Request %s failed due to %s.", request_id, s)
            raise EngineGenerateError() from e
        finally:
            if q is not None:
                q.close()
646
647
648
649
650
651
652
653
654
655
656
657

    def _run_output_handler(self):
        """Start (once) the background task that drains EngineCore outputs.

        The task loops forever:
        1. Fetch a batch of outputs from EngineCore.
        2. Feed the batch to the OutputProcessor in chunks of
           ``VLLM_V1_OUTPUT_PROC_CHUNK_SIZE``, yielding to the event loop
           between chunks.
        3. Abort requests the processor flags (e.g. stop strings).
        4. Update scheduler stats and record metrics.

        Any exception is logged and propagated to pending requests via
        ``OutputProcessor.propagate_error``. Calling this again after the
        task exists is a no-op.
        """
        if self.output_handler is not None:
            return

        # Capture everything the task needs in locals so the coroutine holds
        # no reference back to ``self``; otherwise the AsyncLLM could not be
        # garbage collected and cleaned up properly.
        core_client = self.engine_core
        out_proc = self.output_processor
        stats_enabled = self.log_stats
        stat_mgr = self.logger_manager
        in_proc = self.input_processor
        step = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE

        async def output_handler():
            try:
                while True:
                    batch = await core_client.get_output_async()
                    total = len(batch.outputs)

                    iter_stats = (
                        IterationStats() if (stats_enabled and total) else None
                    )

                    core_outputs = batch.outputs
                    for lo in range(0, total, step):
                        hi = lo + step
                        processed = out_proc.process_outputs(
                            core_outputs[lo:hi], batch.timestamp, iter_stats
                        )
                        # RequestOutputs are delivered straight to per-request
                        # queues, never returned here.
                        assert not processed.request_outputs

                        # Give other asyncio tasks a turn before the next chunk.
                        if hi < total:
                            await asyncio.sleep(0)

                        # Abort any reqs that finished due to stop strings.
                        if processed.reqs_to_abort:
                            await core_client.abort_requests_async(
                                processed.reqs_to_abort
                            )

                    out_proc.update_scheduler_stats(batch.scheduler_stats)

                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if stat_mgr:
                        stat_mgr.record(
                            engine_idx=batch.engine_index,
                            scheduler_stats=batch.scheduler_stats,
                            iteration_stats=iter_stats,
                            mm_cache_stats=in_proc.stat_mm_cache(),
                        )
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")
                out_proc.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())
714

715
716
717
    async def abort(
        self, request_id: str | Iterable[str], internal: bool = False
    ) -> None:
        """Abort one or more requests in OutputProcessor and EngineCore.

        ``request_id`` may be a single id or an iterable of ids. The
        OutputProcessor returns the full set of ids to abort, and those are
        forwarded to EngineCore.
        """
        if isinstance(request_id, str):
            ids = (request_id,)
        else:
            ids = as_list(request_id)

        expanded_ids = self.output_processor.abort_requests(ids, internal)
        await self.engine_core.abort_requests_async(expanded_ids)

        if not self.log_requests:
            return
        logger.info("Aborted request(s) %s.", ",".join(ids))
728

729
730
731
    async def pause_generation(
        self,
        *,
        mode: PauseMode = "abort",
        wait_for_inflight_requests: bool | None = None,
        clear_cache: bool = True,
    ) -> None:
        """
        Pause generation to allow model weight updates.

        In ``"abort"`` and ``"wait"`` modes the client-side paused flag is set,
        so new generation/encoding requests wait until
        :meth:`resume_generation`. In ``"keep"`` mode only the engine-core
        scheduler is paused.

        Args:
            mode: How to handle in-flight requests:
                - ``"abort"``: Abort all in-flight requests immediately
                  (default).
                - ``"wait"``: Wait for in-flight requests to complete.
                - ``"keep"``: Freeze requests in queue; they resume on
                  :meth:`resume_generation`.
            wait_for_inflight_requests: DEPRECATED: use mode argument.
                Whether to wait for in-flight requests to complete before pausing.
            clear_cache: Whether to reset the prefix, multimodal and encoder
                caches after pausing. Set to ``False`` to preserve caches for
                faster resume. Default is ``True`` (clear caches).

        Raises:
            ValueError: If ``mode`` is not ``"abort"``, ``"wait"`` or
                ``"keep"``. Checked before any state is changed.
            NotImplementedError: If ``mode`` is not ``"keep"`` while running
                with ``--api-server-count > 1``.
        """
        if wait_for_inflight_requests:
            warnings.warn(
                "The `wait_for_inflight_requests` parameter in "
                "`AsyncLLM.pause_generation()` is deprecated. "
                "Please use `mode` argument instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            mode = "wait"

        # Validate before touching any state, so an unknown mode can never
        # leave ``_paused`` set, and is rejected even when already paused.
        if mode not in ("abort", "wait", "keep"):
            raise ValueError(f"Invalid mode: {mode}")

        if mode == "keep":
            # Freeze requests in the scheduler - they will resume on
            # resume_generation().
            await self.engine_core.pause_scheduler_async()
        else:
            if self._client_count > 1:
                raise NotImplementedError(
                    "pause_generation is not supported with --api-server-count > 1"
                    " when mode is not 'keep'"
                )
            async with self._pause_cond:
                if not self._paused:
                    self._paused = True

                    if mode == "abort":
                        request_ids = list(self.output_processor.request_states.keys())
                        if request_ids:
                            await self.abort(request_ids, internal=True)
                    elif self.output_processor.has_unfinished_requests():
                        # mode == "wait": block until in-flight requests finish.
                        await self.output_processor.wait_for_requests_to_drain()

        # Clear cache
        if clear_cache:
            await self.reset_prefix_cache()
            await self.reset_mm_cache()
            await self.reset_encoder_cache()

    async def resume_generation(self) -> None:
        """
        Resume generation after :meth:`pause_generation`.

        While holding the pause condition, resumes the engine-core scheduler,
        clears the paused flag and wakes every coroutine waiting on it.
        """
        cond = self._pause_cond
        async with cond:
            await self.engine_core.resume_scheduler_async()
            self._paused = False
            # Wake up all waiting requests.
            cond.notify_all()

    async def is_paused(self) -> bool:
        """Return the paused flag, read while holding the pause condition."""
        async with self._pause_cond:
            paused = self._paused
        return paused

    async def encode(
        self,
        prompt: PromptType | DictPrompt | TokPrompt,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Main function called by the API server to kick off a pooling request.

            * 1) Register the request via :meth:`add_request`, which returns
              a per-request output collector.
            * 2) A separate output_handler loop, running in a background
              asyncio task, pulls outputs from EngineCore and pushes them into
              that collector.
            * 3) This generator pulls from the collector and yields each
              PoolingRequestOutput until one is marked finished.

        If the caller is cancelled or stops iterating before the request
        finishes, the request is aborted.

        Raises:
            asyncio.CancelledError: Re-raised after aborting the request.
            EngineDeadError: Re-raised as-is (no abort; engine is shutting down).
            ValueError: Re-raised as-is (request validation error).
            EngineGenerateError: Wraps any other unexpected exception.
        """

        q: RequestOutputCollector | None = None
        # Defined before ``try`` so the cleanup handlers can always read it.
        finished = False
        try:
            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # The client disconnected (encode() is cancelled), or the consumer
        # stopped iterating early (aclose()/garbage collection throws
        # GeneratorExit at the yield). Abort the request unless it already
        # finished, so the engine does not keep producing unread output.
        except (asyncio.CancelledError, GeneratorExit):
            if not finished:
                if q is not None:
                    await self.abort(q.request_id, internal=True)
                if self.log_requests:
                    logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

        finally:
            if q is not None:
                q.close()

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """The input processor's tokenizer (may be ``None``)."""
        processor = self.input_processor
        return processor.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        """Delegate to ``input_processor.get_tokenizer()``."""
        processor = self.input_processor
        return processor.get_tokenizer()

    @property
    def renderer(self) -> BaseRenderer:
        """The renderer exposed by the input processor."""
        processor = self.input_processor
        return processor.renderer

    async def is_tracing_enabled(self) -> bool:
        """Tracing is enabled iff an OTLP traces endpoint is configured."""
        endpoint = self.observability_config.otlp_traces_endpoint  # type: ignore
        return endpoint is not None

    async def do_log_stats(self) -> None:
        """Ask the stat logger manager, if one is configured, to log stats."""
        manager = self.logger_manager
        if not manager:
            return
        manager.log()

    async def check_health(self) -> None:
        """Raise :attr:`dead_error` if the engine has errored; otherwise return."""
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error

    async def start_profile(self) -> None:
        """
        Start profiling in the engine core and, if a local profiler is
        configured, start it in a worker thread; both run concurrently.
        """
        pending = [self.engine_core.profile_async(True)]
        local_profiler = self.profiler
        if local_profiler is not None:
            pending.append(asyncio.to_thread(local_profiler.start))
        await asyncio.gather(*pending)

    async def stop_profile(self) -> None:
        """
        Stop profiling in the engine core and, if a local profiler is
        configured, stop it in a worker thread; both run concurrently.
        """
        pending = [self.engine_core.profile_async(False)]
        local_profiler = self.profiler
        if local_profiler is not None:
            pending.append(asyncio.to_thread(local_profiler.stop))
        await asyncio.gather(*pending)

    async def reset_mm_cache(self) -> None:
        """
        Clear the multimodal cache on the input processor first, then ask the
        engine core to reset its own multimodal cache.
        """
        processor = self.input_processor
        processor.clear_mm_cache()
        engine = self.engine_core
        await engine.reset_mm_cache_async()

    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """
        Reset the prefix cache in the engine core.

        Args:
            reset_running_requests: Forwarded to the engine core.
            reset_connector: Forwarded to the engine core.

        Returns:
            The boolean result reported by the engine core.
        """
        result = await self.engine_core.reset_prefix_cache_async(
            reset_running_requests, reset_connector
        )
        return result

    async def reset_encoder_cache(self) -> None:
        """Ask the engine core to reset its encoder cache."""
        engine = self.engine_core
        await engine.reset_encoder_cache_async()

    async def sleep(self, level: int = 1) -> None:
        """
        Put the engine to sleep at ``level``.

        The prefix cache is reset first. Afterwards the sleep state
        ``(1, level)`` is recorded on the stat logger manager, if present.
        """
        await self.reset_prefix_cache()
        await self.engine_core.sleep_async(level)

        manager = self.logger_manager
        if manager is not None:
            manager.record_sleep_state(1, level)

    async def wake_up(self, tags: list[str] | None = None) -> None:
947
        await self.engine_core.wake_up_async(tags)
948

949
950
951
        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

952
953
954
    async def is_sleeping(self) -> bool:
        """Query the engine core for whether it is currently sleeping."""
        sleeping = await self.engine_core.is_sleeping_async()
        return sleeping

    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """
        Load a new LoRA adapter into the engine for future requests.

        Returns the boolean result reported by the engine core.
        """
        loaded = await self.engine_core.add_lora_async(lora_request)
        return loaded

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter; returns the engine core's result."""
        removed = await self.engine_core.remove_lora_async(lora_id)
        return removed

    async def list_loras(self) -> set[int]:
        """Return the ids of all adapters registered with the engine core."""
        registered = await self.engine_core.list_loras_async()
        return registered

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted; returns the engine core's result."""
        pinned = await self.engine_core.pin_lora_async(lora_id)
        return pinned

    async def collective_rpc(
        self,
        method: str,
974
        timeout: float | None = None,
975
        args: tuple = (),
976
        kwargs: dict | None = None,
977
    ):
978
979
980
981
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
982
983
            method, timeout, args, kwargs
        )
984

985
986
987
988
989
990
991
992
    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """Wait for all requests to be drained.

        Polls ``engine_core.dp_engines_running()`` once per second until it
        reports that no engines are running.

        Args:
            drain_timeout: Maximum number of seconds to wait.

        Raises:
            TimeoutError: If engines are still running after ``drain_timeout``
                seconds.
        """
        # Use a monotonic clock so wall-clock adjustments (NTP, manual changes)
        # cannot shorten or extend the timeout.
        deadline = time.monotonic() + drain_timeout
        while time.monotonic() < deadline:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

            logger.info("Engines are still running, waiting for requests to drain...")
            await asyncio.sleep(1)  # Wait 1 second before checking again

        raise TimeoutError(
            f"Timeout reached after {drain_timeout} seconds "
            "waiting for requests to drain."
        )

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ):
        """
        Scale up or down the data parallel size by adding or removing
        engine cores.

        In-flight requests are drained first. The call is a no-op if the
        requested size equals the current one.

        Args:
            new_data_parallel_size: The new number of data parallel workers
            drain_timeout:
                Maximum time to wait for requests to drain (seconds)

        Raises:
            TimeoutError: If requests do not drain within ``drain_timeout``
                seconds (propagated from :meth:`wait_for_requests_to_drain`).
        """
        old_data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if old_data_parallel_size == new_data_parallel_size:
            logger.info(
                "Data parallel size is already %s, skipping scale",
                new_data_parallel_size,
            )
            return
        # Direction-neutral wording: this path handles scaling down as well.
        logger.info(
            "Waiting for requests to drain before scaling to %s engines...",
            new_data_parallel_size,
        )
        await self.wait_for_requests_to_drain(drain_timeout)
        logger.info(
            "Requests have been drained, proceeding with scale to %s engines",
            new_data_parallel_size,
        )
        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size

        # recreate stat loggers
        # NOTE(review): loggers are only rebuilt when scaling up; on scale-down
        # the existing manager is kept unchanged -- confirm this is intended.
        if new_data_parallel_size > old_data_parallel_size and self.log_stats:
            # TODO(rob): fix this after talking with Ray team.
            # This resets all the prometheus metrics since we
            # unregister during initialization. Need to understand
            # the intended behavior here better.
            self.logger_manager = StatLoggerManager(
                vllm_config=self.vllm_config,
                engine_idxs=list(range(new_data_parallel_size)),
                custom_stat_loggers=None,
            )

    @property
    def is_running(self) -> bool:
        """
        True until the background output handler task has completed.

        ``output_handler`` is None before the loop is started, which also
        counts as running.
        """
        handler = self.output_handler
        if handler is None:
            return True
        return not handler.done()

    @property
    def is_stopped(self) -> bool:
        """Same value as :attr:`errored`."""
        stopped = self.errored
        return stopped

    @property
    def errored(self) -> bool:
        """
        Whether the engine is unusable: the engine core reports itself dead,
        or the output handler is no longer running.
        """
        engine_dead = self.engine_core.resources.engine_dead
        return engine_dead or not self.is_running

    @property
    def dead_error(self) -> BaseException:
        """A new :class:`EngineDeadError` instance, created on every access."""
        error: BaseException = EngineDeadError()
        return error

    async def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest
    ) -> None:
        """
        Initialize weight transfer for RL training.

        Forwards ``request.init_info`` through :meth:`collective_rpc`
        (method ``"init_weight_transfer_engine"``).

        Args:
            request: Weight transfer initialization request with backend-specific info

        Raises:
            TypeError: If ``request`` is not a ``WeightTransferInitRequest``.
        """
        if not isinstance(request, WeightTransferInitRequest):
            raise TypeError(f"Expected WeightTransferInitRequest, got {type(request)}")

        await self.collective_rpc(
            "init_weight_transfer_engine", kwargs={"init_info": request.init_info}
        )

    async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
        """
        Apply a batched weight update for RL training.

        Forwards ``request.update_info`` through :meth:`collective_rpc`
        (method ``"update_weights"``).

        Args:
            request: Weight update request with backend-specific update info

        Raises:
            TypeError: If ``request`` is not a ``WeightTransferUpdateRequest``.
        """
        if not isinstance(request, WeightTransferUpdateRequest):
            raise TypeError(
                f"Expected WeightTransferUpdateRequest, got {type(request)}"
            )

        payload = {"update_info": request.update_info}
        await self.collective_rpc("update_weights", kwargs=payload)