async_llm.py 40.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import os
import socket
6
import time
7
import warnings
8
from collections.abc import AsyncGenerator, Iterable, Mapping
9
from copy import copy
10
from typing import Any
11

12
import torch
13

14
import vllm.envs as envs
15
from vllm import TokensPrompt
16
from vllm.config import VllmConfig
17
18
19
20
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
21
from vllm.engine.arg_utils import AsyncEngineArgs
22
from vllm.engine.protocol import EngineClient
23
from vllm.inputs import PromptType, StreamingInput
24
25
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
26
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
27
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
28
from vllm.plugins.io_processors import get_io_processor
29
from vllm.pooling_params import PoolingParams
30
from vllm.renderers import BaseRenderer, merge_kwargs
31
32
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import extract_prompt_components
33
from vllm.sampling_params import RequestOutputKind, SamplingParams
34
from vllm.tasks import SupportedTask
35
from vllm.tokenizers import TokenizerLike
36
from vllm.tracing import init_tracer
37
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
38
from vllm.usage.usage_lib import UsageContext
39
40
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
41
from vllm.v1.engine import EngineCoreRequest
42
from vllm.v1.engine.core_client import EngineCoreClient
43
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
44
from vllm.v1.engine.input_processor import InputProcessor
45
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
46
from vllm.v1.engine.parallel_sampling import ParentRequest
47
from vllm.v1.executor import Executor
48
49
50
51
52
from vllm.v1.metrics.loggers import (
    StatLoggerFactory,
    StatLoggerManager,
    load_stat_logger_plugin_factories,
)
53
from vllm.v1.metrics.prometheus import shutdown_prometheus
54
from vllm.v1.metrics.stats import IterationStats
55
56
57
58

# Module-level logger, obtained through vLLM's `init_logger` helper and
# shared by everything defined in this file.
logger = init_logger(__name__)


59
60
61
62
63
64
65
66
67
68
69
70
class InputStreamError(Exception):
    """Carrier for an exception raised by a caller-supplied input stream.

    The original exception is kept on ``cause`` so that a consumer can
    re-raise it as-is rather than wrapping it in EngineGenerateError.
    The message of this wrapper is ``str(cause)``.
    """

    def __init__(self, cause: Exception):
        # Initialise the Exception base with the textual form of the
        # wrapped error, then remember the wrapped error itself.
        super().__init__(str(cause))
        self.cause = cause


71
72
73
74
class AsyncLLM(EngineClient):
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: list[StatLoggerFactory] | None = None,
        aggregate_engine_logging: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats. Forced on when any custom
                stat logger is present (passed in or loaded from plugins).
            usage_context: Usage context of the LLM. Not read by this
                constructor body.
            mm_registry: Multi-modal registry. Not read by this
                constructor body.
            use_cached_outputs: Whether to use cached outputs. Not read
                by this constructor body.
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop. Not read
                by this constructor body.
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            aggregate_engine_logging: Forwarded to StatLoggerManager.
            client_addresses: Forwarded to
                EngineCoreClient.make_async_mp_client.
            client_count: Forwarded to the engine core client and to
                StatLoggerManager.
            client_index: Forwarded to the engine core client.
        """
        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config

        # Tracing is initialised only when an OTLP endpoint is configured.
        tracing_endpoint = self.observability_config.otlp_traces_endpoint
        if tracing_endpoint is not None:
            init_tracer("vllm.llm_engine", tracing_endpoint)

        self.log_requests = log_requests

        # Custom loggers = explicitly supplied factories + plugin factories.
        custom_stat_loggers = list(stat_loggers or [])
        custom_stat_loggers.extend(load_stat_logger_plugin_factories())

        has_custom_loggers = bool(custom_stat_loggers)
        self.log_stats = log_stats or has_custom_loggers
        if not log_stats and has_custom_loggers:
            logger.info(
                "AsyncLLM created with log_stats=False, "
                "but custom stat loggers were found; "
                "enabling logging without default stat loggers."
            )

        self.input_processor = InputProcessor(self.vllm_config)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        # NOTE(review): `self.tokenizer` is not assigned in this
        # constructor; presumably a property defined elsewhere on the
        # class - confirm.
        self.output_processor = OutputProcessor(
            self.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
        )
        if tracing_endpoint is not None:
            self.output_processor.tracing_enabled = True

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Loggers.
        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=custom_stat_loggers,
                # Default loggers follow the caller's original flag, not
                # the value forced on by custom loggers.
                enable_default_loggers=log_stats,
                client_count=client_count,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        # Pause / resume state for async RL workflows.
        self._pause_cond = asyncio.Condition()
        self._paused = False

        self.output_handler: asyncio.Task | None = None
        try:
            # Start output handler eagerly if we are in the asyncio eventloop.
            asyncio.get_running_loop()
            self._run_output_handler()
        except RuntimeError:
            # No running event loop: the handler is started later, on the
            # first call that needs it.
            pass

        if (
            vllm_config.profiler_config.profiler == "torch"
            and not vllm_config.profiler_config.ignore_frontend
        ):
            profiler_dir = vllm_config.profiler_config.torch_profiler_dir
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                profiler_dir,
            )
            # Trace files are tagged with host name and PID.
            worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=vllm_config.profiler_config.torch_profiler_with_stack,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    profiler_dir,
                    worker_name=worker_name,
                    use_gzip=vllm_config.profiler_config.torch_profiler_use_gzip,
                ),
            )
        else:
            self.profiler = None

206
207
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_log_requests: bool = False,
        aggregate_engine_logging: bool = False,
        disable_log_stats: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from an already-built VllmConfig.

        The executor class is resolved with ``Executor.get_class``.
        ``enable_log_requests`` maps to the constructor's ``log_requests``
        and ``disable_log_stats`` is inverted into ``log_stats``; every
        other argument is forwarded under the same name.
        """
        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            log_requests=enable_log_requests,
            log_stats=not disable_log_stats,
            aggregate_engine_logging=aggregate_engine_logging,
            usage_context=usage_context,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

235
236
237
238
239
240
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs.

        The engine config is built with
        ``engine_args.create_engine_config(usage_context)``. Request
        logging and stats logging are taken from
        ``engine_args.enable_log_requests`` and
        ``engine_args.disable_log_stats`` (inverted).
        """

        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

260
261
262
    def __del__(self):
        """Run `shutdown()` when the instance is finalized."""
        self.shutdown()

263
264
265
    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC.

        Each attribute is looked up with ``getattr(..., None)``, so the
        method also works on an instance whose constructor did not get
        as far as setting it.
        """

        shutdown_prometheus()

        if engine_core := getattr(self, "engine_core", None):
            engine_core.shutdown()

        if input_processor := getattr(self, "input_processor", None):
            input_processor.close()

        # Cancel the background output-handler task, if one was started.
        handler = getattr(self, "output_handler", None)
        if handler is not None:
            cancel_task_threadsafe(handler)
277

278
    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks reported by the engine core.

        The first call awaits ``engine_core.get_supported_tasks_async()``
        and stores the result on ``self._supported_tasks``; later calls
        return the stored value.
        """
        if not hasattr(self, "_supported_tasks"):
            # Cache the result
            self._supported_tasks = await self.engine_core.get_supported_tasks_async()

        return self._supported_tasks
284

285
286
287
    async def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest
        | PromptType
        | DictPrompt
        | TokPrompt
        | AsyncGenerator[StreamingInput, None],
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        prompt_text: str | None = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        Args:
            request_id: Caller-side identifier. Ignored (with a one-time
                warning) when ``prompt`` is an EngineCoreRequest carrying
                a different ``request_id``.
            prompt: A ready EngineCoreRequest, a prompt to be processed by
                the input processor, or an async generator of
                StreamingInput chunks (streaming input).
            params: Sampling or pooling parameters.
            arrival_time, lora_request, tokenization_kwargs, trace_headers,
            priority, data_parallel_rank: Forwarded to input processing.
            prompt_text: Prompt text to associate with an
                EngineCoreRequest. Only allowed together with an
                EngineCoreRequest prompt.

        Returns:
            The RequestOutputCollector that receives this request's
            outputs. With ``n > 1`` all child requests share it.

        Raises:
            EngineDeadError: If ``self.errored`` is truthy.
            ValueError: If prompt logprobs are requested while
                kv-sharing fast prefill is enabled, or if ``prompt_text``
                is given with a prompt that is not an EngineCoreRequest.
        """

        if self.errored:
            raise EngineDeadError()

        is_pooling = isinstance(params, PoolingParams)

        if (
            self.vllm_config.cache_config.kv_sharing_fast_prefill
            and not is_pooling
            and params.prompt_logprobs
        ):
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs"
            )

        # Deprecated path: fold params.truncate_prompt_tokens into
        # tokenization_kwargs.
        if params.truncate_prompt_tokens is not None:
            params_type = type(params).__name__
            warnings.warn(
                f"The `truncate_prompt_tokens` parameter in `{params_type}` "
                "is deprecated and will be removed in v0.16. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )

            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=params.truncate_prompt_tokens),
            )

        if isinstance(prompt, AsyncGenerator):
            # Streaming input case.
            return await self._add_streaming_input_request(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
                data_parallel_rank,
            )

        # Convert Input --> Request.
        if isinstance(prompt, EngineCoreRequest):
            request = prompt
            if request_id != request.request_id:
                logger.warning_once(
                    "AsyncLLM.add_request() was passed a request_id parameter that "
                    "does not match the EngineCoreRequest.request_id attribute. The "
                    "latter will be used, and the former will be ignored."
                )
        else:
            if prompt_text is not None:
                raise ValueError(
                    "should only provide prompt_text with EngineCoreRequest"
                )
            request = self.input_processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                supported_tasks=await self.get_supported_tasks(),
            )
            prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)

        self.input_processor.assign_request_id(request)

        # We start the output_handler on the first call to add_request() so
        # we can call __init__ before the event loop, which enables us
        # to handle startup failure gracefully in the OpenAI server.
        self._run_output_handler()

        # Respect pause state before accepting new requests.
        async with self._pause_cond:
            await self._pause_cond.wait_for(lambda: not self._paused)

        # Create a new output collector for the request.
        queue = RequestOutputCollector(params.output_kind, request.request_id)

        # Use cloned params that may have been updated in process_inputs()
        params = request.params

        if is_pooling or params.n == 1:
            await self._add_request(request, prompt_text, None, 0, queue)
            return queue

        parent_params = params
        assert isinstance(parent_params, SamplingParams)

        # Fan out child requests (for n>1).
        parent_request = ParentRequest(request)
        for idx in range(parent_params.n):
            request_id, child_params = parent_request.get_child_info(idx)
            # The last child reuses the original request object; earlier
            # children get a shallow copy.
            child_request = request if idx == parent_params.n - 1 else copy(request)
            child_request.request_id = request_id
            child_request.sampling_params = child_params
            await self._add_request(
                child_request, prompt_text, parent_request, idx, queue
            )
        return queue
412

413
414
415
    async def _add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        index: int,
        queue: RequestOutputCollector,
    ):
        """Register one request with the OutputProcessor, then the EngineCore.

        Args:
            request: The request to submit.
            prompt: Prompt text passed to the OutputProcessor, or None.
            parent_req: Parent request when fanning out ``n > 1``,
                otherwise None.
            index: Index of this request under its parent (0 when there
                is no parent).
            queue: Collector that will receive the request's outputs.
        """
        # Add the request to OutputProcessor (this process).
        self.output_processor.add_request(request, prompt, parent_req, index, queue)

        # Add the EngineCoreRequest to EngineCore (separate process).
        await self.engine_core.add_request_async(request)

        if self.log_requests:
            logger.info("Added request %s.", request.request_id)
429

430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
    async def _add_streaming_input_request(
        self,
        request_id: str,
        input_stream: AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> RequestOutputCollector:
        """Submit a request whose prompt arrives as a stream of chunks.

        A background task consumes ``input_stream`` and submits every
        chunk as a resumable request under one internal request id. When
        the stream ends without cancellation, a final request is
        submitted to signal that the inputs are complete.

        Returns:
            The RequestOutputCollector shared by all chunks. The consumer
            task is stored on its ``_input_stream_task`` attribute.

        Raises:
            ValueError: If ``sampling_params`` is rejected by
                ``_validate_streaming_input_sampling_params``.
        """
        self._validate_streaming_input_sampling_params(sampling_params)

        # Keyword arguments shared by every process_inputs() call below.
        inputs = dict(
            arrival_time=arrival_time,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            trace_headers=trace_headers,
            priority=priority,
            data_parallel_rank=data_parallel_rank,
        )

        # Clone once here and mark the clone with skip_clone=True.
        if not sampling_params.skip_clone:
            sampling_params = sampling_params.clone()
            sampling_params.skip_clone = True

        # Create request for validation, also used as the finished signal
        # once the input stream is closed.
        final_req = self.input_processor.process_inputs(
            request_id=request_id,
            prompt=TokensPrompt(prompt_token_ids=[0]),
            params=sampling_params,
            **inputs,  # type: ignore[arg-type]
        )
        self.input_processor.assign_request_id(final_req)
        internal_req_id = final_req.request_id

        queue = RequestOutputCollector(sampling_params.output_kind, internal_req_id)

        async def handle_inputs():
            cancelled = False
            try:
                async for input_chunk in input_stream:
                    # Per-chunk sampling params override the request-level
                    # ones and are validated the same way.
                    sp = input_chunk.sampling_params
                    if sp:
                        self._validate_streaming_input_sampling_params(sp)
                    else:
                        sp = sampling_params
                    # TODO(nick): Avoid re-validating reused sampling parameters
                    req = self.input_processor.process_inputs(
                        request_id=internal_req_id,
                        prompt=input_chunk.prompt,
                        params=sp,
                        resumable=True,
                        **inputs,  # type: ignore[arg-type]
                    )
                    req.external_req_id = request_id
                    if req.prompt_embeds is not None:
                        raise ValueError(
                            "prompt_embeds not supported for streaming inputs"
                        )
                    prompt_text, _, _ = extract_prompt_components(
                        self.model_config, input_chunk.prompt
                    )
                    await self._add_request(req, prompt_text, None, 0, queue)
            except (asyncio.CancelledError, GeneratorExit):
                cancelled = True
            except Exception as error:
                # Wrap in InputStreamError so generate() can propagate it
                # without wrapping in EngineGenerateError.
                queue.put(InputStreamError(error))
            finally:
                queue._input_stream_task = None
                if not cancelled:
                    # Send empty final request to indicate that inputs have
                    # finished. Don't send if cancelled (session was aborted).
                    await self._add_request(final_req, None, None, 0, queue)

        # Ensure output handler is running.
        self._run_output_handler()

        queue._input_stream_task = asyncio.create_task(handle_inputs())
        return queue

    @staticmethod
    def _validate_streaming_input_sampling_params(
        params: SamplingParams | PoolingParams,
    ):
        """Reject parameter sets that streaming input cannot handle.

        Accepted parameters are SamplingParams instances with ``n`` not
        greater than 1, an ``output_kind`` other than FINAL_ONLY and no
        stop strings. Anything else raises ValueError.
        """
        supported = (
            isinstance(params, SamplingParams)
            and not (params.n > 1)
            and not (params.output_kind == RequestOutputKind.FINAL_ONLY)
            and not params.stop
        )
        if supported:
            return

        raise ValueError(
            "Input streaming not currently supported "
            "for pooling models, n > 1, request_kind = FINAL_ONLY "
            "or with stop strings."
        )

531
532
533
534
535
    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
536
    async def generate(
        self,
        prompt: EngineCoreRequest
        | PromptType
        | DictPrompt
        | TokPrompt
        | AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Making a RequestOutputCollector corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the OutputProcessor.
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request RequestOutputCollector.

        The caller of generate() iterates the returned AsyncGenerator,
        which yields each RequestOutput until one is marked finished.

        Raises:
            asyncio.CancelledError / GeneratorExit: Re-raised after the
                request has been aborted.
            EngineDeadError: Re-raised without aborting.
            ValueError: Request validation error, re-raised as-is.
            Exception: The original error of a failing input stream
                (unwrapped from InputStreamError).
            EngineGenerateError: Any other unexpected error.
        """

        q: RequestOutputCollector | None = None
        try:
            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                prompt_text=prompt_text,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                assert isinstance(out, RequestOutput)
                finished = out.finished
                # The STREAM_FINISHED sentinel is consumed, not yielded.
                if out is not STREAM_FINISHED:
                    yield out

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError as e:
            if self.log_requests:
                logger.info("Request %s failed (bad request): %s.", request_id, e)
            raise

        # Error from input stream generator - propagate directly.
        except InputStreamError as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed (input error): %s.", request_id, e)
            raise e.cause from e

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                try:
                    s = f"{e.__class__.__name__}: {e}"
                except Exception as e2:
                    # str(e) itself failed; log the class names only. The
                    # trailing space keeps the second class name from
                    # running into the word "class".
                    s = (
                        f"{e.__class__.__name__}: "
                        "error during printing an exception of class "
                        + e2.__class__.__name__
                    )
                logger.info("Request %s failed due to %s.", request_id, s)
            raise EngineGenerateError() from e
        finally:
            if q is not None:
                q.close()
645
646
647
648
649
650
651
652
653
654
655
656

    def _run_output_handler(self):
        """Start the background output loop if it is not running yet.

        The loop pulls EngineCoreOutputs from the EngineCore, hands them
        to the OutputProcessor (which pushes RequestOutputs into the
        per-request queues) and records stats. Calling this method again
        once ``self.output_handler`` is set is a no-op.
        """

        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        logger_manager = self.logger_manager
        input_processor = self.input_processor
        chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    num_outputs = len(outputs.outputs)

                    # Stats are only gathered for non-empty batches.
                    iteration_stats = (
                        IterationStats() if (log_stats and num_outputs) else None
                    )

                    # Split outputs into chunks of at most
                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                    # event loop for too long.
                    engine_core_outputs = outputs.outputs
                    for start in range(0, num_outputs, chunk_size):
                        end = start + chunk_size
                        outputs_slice = engine_core_outputs[start:end]
                        # 2) Process EngineCoreOutputs.
                        processed_outputs = output_processor.process_outputs(
                            outputs_slice, outputs.timestamp, iteration_stats
                        )
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed_outputs.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if end < num_outputs:
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        if processed_outputs.reqs_to_abort:
                            await engine_core.abort_requests_async(
                                processed_outputs.reqs_to_abort
                            )

                    output_processor.update_scheduler_stats(outputs.scheduler_stats)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if logger_manager:
                        logger_manager.record(
                            engine_idx=outputs.engine_index,
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                            mm_cache_stats=input_processor.stat_mm_cache(),
                        )
            except Exception as e:
                # Log, then hand the error to the OutputProcessor so it can
                # be propagated.
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())
713

714
715
716
    async def abort(
        self, request_id: str | Iterable[str], internal: bool = False
    ) -> None:
        """Abort RequestId in OutputProcessor and EngineCore.

        Args:
            request_id: A single request id or an iterable of ids.
            internal: Passed through to
                ``output_processor.abort_requests``.
        """

        # Normalise to a sequence; a bare string must not be iterated
        # character by character.
        request_ids = (
            (request_id,) if isinstance(request_id, str) else as_list(request_id)
        )
        # The OutputProcessor returns the ids to abort in the EngineCore.
        all_request_ids = self.output_processor.abort_requests(request_ids, internal)
        await self.engine_core.abort_requests_async(all_request_ids)

        if self.log_requests:
            logger.info("Aborted request(s) %s.", ",".join(request_ids))
727

728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
    async def pause_generation(
        self,
        *,
        wait_for_inflight_requests: bool = False,
        clear_cache: bool = True,
    ) -> None:
        """
        Pause generation to allow model weight updates.

        New generation/encoding requests are blocked until resume.
        Calling this while already paused returns immediately.

        Args:
            wait_for_inflight_requests: When ``True`` waits for in-flight
                requests to finish before pausing. When ``False`` (default),
                immediately aborts any in-flight requests. In both cases
                the method then waits until the OutputProcessor reports
                no unfinished requests.
            clear_cache: Whether to clear KV cache and prefix cache after
                draining. Set to ``False`` to preserve cache for faster resume.
                Default is ``True`` (clear caches).
        """

        async with self._pause_cond:
            if self._paused:
                return
            self._paused = True

        if not wait_for_inflight_requests:
            # Snapshot the ids first; abort() changes the tracked requests.
            request_ids = list(self.output_processor.request_states.keys())
            if request_ids:
                await self.abort(request_ids, internal=True)

        # Wait for running requests to drain before clearing cache.
        if self.output_processor.has_unfinished_requests():
            await self.output_processor.wait_for_requests_to_drain()

        # Clear cache
        if clear_cache:
            await self.reset_prefix_cache()
            await self.reset_mm_cache()
            await self.reset_encoder_cache()
767
768
769
770
771
772
773
774
775
776
777
778
779
780

    async def resume_generation(self) -> None:
        """Resume generation after :meth:`pause_generation`."""
        pause_cond = self._pause_cond
        async with pause_cond:
            self._paused = False
            # Wake up every task currently waiting on the pause condition.
            pause_cond.notify_all()

    async def is_paused(self) -> bool:
        """Return whether the engine is currently paused."""
        # Read the flag while holding the pause condition.
        async with self._pause_cond:
            paused = self._paused
        return paused

781
    async def encode(
        self,
        prompt: PromptType | DictPrompt | TokPrompt,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Submit a pooling request and stream its outputs back to the caller.

        The request is registered through ``add_request``, which returns the
        per-request output collector. This generator then pulls items from
        that collector and yields each one, stopping after the first output
        whose ``finished`` flag is set.

        Raises:
            asyncio.CancelledError: Re-raised after aborting the request.
            EngineDeadError: Re-raised as-is; no abort is issued.
            ValueError: Re-raised as-is.
            EngineGenerateError: Raised in place of any other exception,
                after aborting the request.
        """
        collector: RequestOutputCollector | None = None
        try:
            collector = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            while True:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = collector.get_nowait() or await collector.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished. Capture the flag
                # before handing the output to the caller.
                is_last = out.finished
                yield out
                if is_last:
                    break

        # If the request is disconnected by the client, this generator
        # is cancelled. So, we abort the request if we end up here.
        except asyncio.CancelledError:
            if collector is not None:
                await self.abort(collector.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in this task (possibly recoverable).
        except Exception as e:
            if collector is not None:
                await self.abort(collector.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

        finally:
            # Always release the collector, if one was created.
            if collector is not None:
                collector.close()
861

862
    @property
    def tokenizer(self) -> TokenizerLike | None:
        """Tokenizer exposed by the input processor (may be ``None``)."""
        processor = self.input_processor
        return processor.tokenizer
865

866
867
    def get_tokenizer(self) -> TokenizerLike:
        """Return the result of the input processor's ``get_tokenizer()``."""
        processor = self.input_processor
        return processor.get_tokenizer()
868

869
    @property
    def renderer(self) -> BaseRenderer:
        """Renderer exposed by the input processor."""
        processor = self.input_processor
        return processor.renderer
872
873

    async def is_tracing_enabled(self) -> bool:
        """Return ``True`` when an OTLP traces endpoint is configured."""
        endpoint = self.observability_config.otlp_traces_endpoint  # type: ignore
        return endpoint is not None
875

876
    async def do_log_stats(self) -> None:
        """Call ``log()`` on the stat logger manager when one is set."""
        manager = self.logger_manager
        if not manager:
            return
        manager.log()
879
880
881

    async def check_health(self) -> None:
        """Raise ``self.dead_error`` when the engine reports ``errored``."""
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error
884
885

    async def start_profile(self) -> None:
        """Start profiling in the engine core and, if set, the local profiler."""
        pending = [self.engine_core.profile_async(True)]
        local_profiler = self.profiler
        if local_profiler is not None:
            # The local profiler's start() is a plain callable, so it is
            # run in a worker thread alongside the engine-core call.
            pending.append(asyncio.to_thread(local_profiler.start))
        await asyncio.gather(*pending)
890
891

    async def stop_profile(self) -> None:
        """Stop profiling in the engine core and, if set, the local profiler."""
        pending = [self.engine_core.profile_async(False)]
        local_profiler = self.profiler
        if local_profiler is not None:
            # The local profiler's stop() is a plain callable, so it is
            # run in a worker thread alongside the engine-core call.
            pending.append(asyncio.to_thread(local_profiler.stop))
        await asyncio.gather(*pending)
896

897
    async def reset_mm_cache(self) -> None:
        """Clear the multimodal cache locally and then in the engine core."""
        # Local (input-processor) cache first, then the engine-core side.
        input_processor = self.input_processor
        input_processor.clear_mm_cache()
        await self.engine_core.reset_mm_cache_async()

901
902
903
904
905
906
    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the prefix cache in the engine core.

        Both flags are forwarded positionally to
        ``engine_core.reset_prefix_cache_async`` and its result is returned.
        """
        outcome = await self.engine_core.reset_prefix_cache_async(
            reset_running_requests, reset_connector
        )
        return outcome
907

908
909
910
    async def reset_encoder_cache(self) -> None:
        """Reset the encoder cache in the engine core."""
        core = self.engine_core
        await core.reset_encoder_cache_async()

911
    async def sleep(self, level: int = 1) -> None:
        """Put the engine core to sleep at the given level.

        The prefix cache is reset first; afterwards the sleep state is
        recorded on the stat logger manager when one is set.
        """
        await self.reset_prefix_cache()
        await self.engine_core.sleep_async(level)

        manager = self.logger_manager
        if manager is None:
            return
        manager.record_sleep_state(1, level)

918
    async def wake_up(self, tags: list[str] | None = None) -> None:
919
        await self.engine_core.wake_up_async(tags)
920

921
922
923
        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

924
925
926
    async def is_sleeping(self) -> bool:
        """Return the engine core's answer to ``is_sleeping_async()``."""
        sleeping = await self.engine_core.is_sleeping_async()
        return sleeping

927
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        # Delegates to the engine core and returns its result.
        added = await self.engine_core.add_lora_async(lora_request)
        return added

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        # Delegates to the engine core and returns its result.
        removed = await self.engine_core.remove_lora_async(lora_id)
        return removed

935
    async def list_loras(self) -> set[int]:
        """List all registered adapters."""
        # Delegates to the engine core and returns its result.
        adapter_ids = await self.engine_core.list_loras_async()
        return adapter_ids

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        # Delegates to the engine core and returns its result.
        pinned = await self.engine_core.pin_lora_async(lora_id)
        return pinned
942

943
944
945
    async def collective_rpc(
        self,
        method: str,
946
        timeout: float | None = None,
947
        args: tuple = (),
948
        kwargs: dict | None = None,
949
    ):
950
951
952
953
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
954
955
            method, timeout, args, kwargs
        )
956

957
958
959
960
961
962
963
964
    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """Wait for all requests to be drained.

        Polls ``engine_core.dp_engines_running()`` roughly once per second
        until it reports that the engines are idle or ``drain_timeout``
        seconds have elapsed. The engines are always polled at least once,
        and once more when the deadline is reached.

        Args:
            drain_timeout: Maximum time to wait, in seconds.

        Raises:
            TimeoutError: If the engines are still running at the deadline.
        """
        # Use the monotonic clock: wall-clock time can jump (NTP sync,
        # manual changes), which would stretch or cut short the timeout.
        deadline = time.monotonic() + drain_timeout
        while True:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break

            logger.info("Engines are still running, waiting for requests to drain...")
            # Re-check every second, but never sleep past the deadline.
            await asyncio.sleep(min(1.0, remaining))

        raise TimeoutError(
            f"Timeout reached after {drain_timeout} seconds "
            "waiting for requests to drain."
        )
972

973
974
975
    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ):
        """
        Scale up or down the data parallel size by adding or removing
        engine cores.

        Args:
            new_data_parallel_size: The new number of data parallel workers
            drain_timeout:
                Maximum time to wait for requests to drain (seconds)
        """
        old_data_parallel_size = self.vllm_config.parallel_config.data_parallel_size

        # Nothing to do when the size is unchanged.
        if new_data_parallel_size == old_data_parallel_size:
            logger.info(
                "Data parallel size is already %s, skipping scale",
                new_data_parallel_size,
            )
            return

        # In-flight requests must finish before the engine set is changed.
        logger.info(
            "Waiting for requests to drain before scaling up to %s engines...",
            new_data_parallel_size,
        )
        await self.wait_for_requests_to_drain(drain_timeout)

        logger.info(
            "Requests have been drained, proceeding with scale to %s engines",
            new_data_parallel_size,
        )
        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size

        # recreate stat loggers (only when scaling up with stats enabled)
        scaled_up = new_data_parallel_size > old_data_parallel_size
        if not (scaled_up and self.log_stats):
            return

        # TODO(rob): fix this after talking with Ray team.
        # This resets all the prometheus metrics since we
        # unregister during initialization. Need to understand
        # the intended behavior here better.
        self.logger_manager = StatLoggerManager(
            vllm_config=self.vllm_config,
            engine_idxs=[*range(new_data_parallel_size)],
            custom_stat_loggers=None,
        )

1015
1016
    @property
    def is_running(self) -> bool:
        """Whether the output handler is not yet done (or not yet started)."""
        handler = self.output_handler
        # Is None before the loop is started.
        if handler is None:
            return True
        return not handler.done()
1019
1020
1021

    @property
    def is_stopped(self) -> bool:
        """Alias for :attr:`errored`."""
        stopped = self.errored
        return stopped
1023
1024
1025

    @property
    def errored(self) -> bool:
        """Whether the engine core is flagged dead or the engine is not running."""
        engine_dead = self.engine_core.resources.engine_dead
        if engine_dead:
            # Short-circuit: ``is_running`` is not consulted in this case.
            return engine_dead
        return not self.is_running
1027
1028
1029

    @property
    def dead_error(self) -> BaseException:
        """A freshly constructed ``EngineDeadError`` instance."""
        error = EngineDeadError()
        return error
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071

    async def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest
    ) -> None:
        """
        Initialize weight transfer for RL training.

        Args:
            request: Weight transfer initialization request with backend-specific info

        Raises:
            TypeError: If ``request`` is not a ``WeightTransferInitRequest``.
        """
        # WeightTransferInitRequest is already imported at module level (the
        # signature annotation above relies on it), so no function-scope
        # re-import is needed here.
        if not isinstance(request, WeightTransferInitRequest):
            raise TypeError(f"Expected WeightTransferInitRequest, got {type(request)}")

        # Forward the backend-specific payload to the workers via RPC.
        await self.collective_rpc(
            "init_weight_transfer_engine", kwargs={"init_info": request.init_info}
        )

    async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
        """
        Batched weight update for RL training.

        Args:
            request: Weight update request with backend-specific update info

        Raises:
            TypeError: If ``request`` is not a ``WeightTransferUpdateRequest``.
        """
        # Reject anything that is not the expected request type up front.
        if not isinstance(request, WeightTransferUpdateRequest):
            raise TypeError(
                f"Expected WeightTransferUpdateRequest, got {type(request)}"
            )

        # Forward the backend-specific payload to the workers via RPC.
        update_info_dict = request.update_info
        await self.collective_rpc(
            "update_weights", kwargs={"update_info": update_info_dict}
        )