# async_llm.py (40.7 KB)
# Newer / Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import os
import socket
6
import time
7
import warnings
8
from collections.abc import AsyncGenerator, Iterable, Mapping
9
from copy import copy
10
from typing import Any
11

12
import torch
13

14
import vllm.envs as envs
15
from vllm import TokensPrompt
16
from vllm.config import VllmConfig
17
18
19
20
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
21
from vllm.engine.arg_utils import AsyncEngineArgs
22
from vllm.engine.protocol import EngineClient, StreamingInput
23
from vllm.inputs import ProcessorInputs, PromptType
24
25
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
26
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
27
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
28
from vllm.plugins.io_processors import get_io_processor
29
from vllm.pooling_params import PoolingParams
30
from vllm.renderers import renderer_from_config
31
from vllm.renderers.inputs.preprocess import extract_prompt_components
32
from vllm.sampling_params import RequestOutputKind, SamplingParams
33
from vllm.tasks import SupportedTask
34
from vllm.tokenizers import TokenizerLike
35
from vllm.tracing import init_tracer
36
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
37
from vllm.usage.usage_lib import UsageContext
38
39
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
40
from vllm.v1.engine import EngineCoreRequest, PauseMode
41
from vllm.v1.engine.core_client import EngineCoreClient
42
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
43
from vllm.v1.engine.input_processor import InputProcessor
44
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
45
from vllm.v1.engine.parallel_sampling import ParentRequest
46
from vllm.v1.executor import Executor
47
48
49
50
51
from vllm.v1.metrics.loggers import (
    StatLoggerFactory,
    StatLoggerManager,
    load_stat_logger_plugin_factories,
)
52
from vllm.v1.metrics.prometheus import shutdown_prometheus
53
from vllm.v1.metrics.stats import IterationStats
54
55
56
57

logger = init_logger(__name__)


class InputStreamError(Exception):
    """Carrier for an exception raised while consuming a user input stream.

    The original exception is kept on ``cause`` so that ``generate()`` can
    re-raise it unchanged instead of wrapping it in ``EngineGenerateError``.
    The exception message is ``str(cause)``.
    """

    def __init__(self, cause: Exception):
        message = str(cause)
        self.cause = cause
        super().__init__(message)


class AsyncLLM(EngineClient):
71
72
    """An asynchronous wrapper for the vLLM engine."""

73
74
75
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: list[StatLoggerFactory] | None = None,
        aggregate_engine_logging: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats. Also decides whether the
                default stat loggers are enabled.
            usage_context: Usage context of the LLM. Not read by this
                constructor.
            mm_registry: Multi-modal registry. Not read by this constructor.
            use_cached_outputs: Whether to use cached outputs. Not read by
                this constructor.
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop. Not read by
                this constructor: the output handler is started here only if
                an asyncio event loop is already running, and otherwise on
                the first add_request().
            stat_loggers: customized stat loggers for the engine. They are
                combined with plugin-provided logger factories; if any custom
                loggers exist, stats are collected even when ``log_stats`` is
                False (but without the default loggers).
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            aggregate_engine_logging: Forwarded to the StatLoggerManager.
            client_addresses: Forwarded to the engine-core client.
            client_count: Number of clients; forwarded to the engine-core
                client and the StatLoggerManager.
            client_index: Index of this client; forwarded to the engine-core
                client.

        Returns:
            None
        """
        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.observability_config = vllm_config.observability_config

        tracing_endpoint = self.observability_config.otlp_traces_endpoint
        if tracing_endpoint is not None:
            init_tracer("vllm.llm_engine", tracing_endpoint)

        self.log_requests = log_requests

        # Custom loggers = user-supplied factories + plugin-provided factories.
        custom_stat_loggers = list(stat_loggers or [])
        custom_stat_loggers.extend(load_stat_logger_plugin_factories())

        has_custom_loggers = bool(custom_stat_loggers)
        # Stats must be collected whenever any logger (default or custom)
        # will consume them.
        self.log_stats = log_stats or has_custom_loggers
        if not log_stats and has_custom_loggers:
            logger.info(
                "AsyncLLM created with log_stats=False, "
                "but custom stat loggers were found; "
                "enabling logging without default stat loggers."
            )

        self.renderer = renderer = renderer_from_config(self.vllm_config)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # Convert TokPrompt --> EngineCoreRequest.
        self.input_processor = InputProcessor(self.vllm_config, renderer)

        # Converts EngineCoreOutputs --> RequestOutput.
        self.output_processor = OutputProcessor(
            renderer.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
            tracing_enabled=tracing_endpoint is not None,
        )

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Loggers.
        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=custom_stat_loggers,
                enable_default_loggers=log_stats,
                client_count=client_count,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        self._client_count = client_count

        self.output_handler: asyncio.Task | None = None
        # Start output handler eagerly if we are in the asyncio eventloop.
        # Only the loop probe is guarded: a RuntimeError raised while starting
        # the handler itself must not be silently swallowed.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass
        else:
            self._run_output_handler()

        if (
            vllm_config.profiler_config.profiler == "torch"
            and not vllm_config.profiler_config.ignore_frontend
        ):
            profiler_dir = vllm_config.profiler_config.torch_profiler_dir
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                profiler_dir,
            )
            # Trace files are named per host and process id.
            worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=vllm_config.profiler_config.torch_profiler_with_stack,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    profiler_dir,
                    worker_name=worker_name,
                    use_gzip=vllm_config.profiler_config.torch_profiler_use_gzip,
                ),
            )
        else:
            self.profiler = None

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_log_requests: bool = False,
        aggregate_engine_logging: bool = False,
        disable_log_stats: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from an already-built VllmConfig.

        The executor class is chosen via ``Executor.get_class(vllm_config)``.
        Note the flag mapping: ``enable_log_requests`` becomes
        ``log_requests`` and ``disable_log_stats`` becomes ``not log_stats``.
        """
        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            log_requests=enable_log_requests,
            log_stats=not disable_log_stats,
            aggregate_engine_logging=aggregate_engine_logging,
            usage_context=usage_context,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs.

        The VllmConfig is built with
        ``engine_args.create_engine_config(usage_context)``. Arguments not
        taken here (aggregate_engine_logging, client_*) use the constructor
        defaults.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    def __del__(self):
        # Clean up on garbage collection. shutdown() reads its attributes
        # with getattr(), so this also works after a failed __init__.
        self.shutdown()

    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC.

        Shuts down Prometheus metrics, the renderer, and the engine-core
        client, then cancels the output-handler task if one was started.
        The instance attributes are read with ``getattr`` so this is safe
        on a partially-initialized instance (e.g. when called from
        ``__del__`` after a failed ``__init__``).
        """
        shutdown_prometheus()

        if renderer := getattr(self, "renderer", None):
            renderer.shutdown()

        if engine_core := getattr(self, "engine_core", None):
            engine_core.shutdown()

        # Cancel the background output-handler task, if it was started.
        handler = getattr(self, "output_handler", None)
        if handler is not None:
            cancel_task_threadsafe(handler)

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the engine.

        The engine is queried on the first call and the result is cached on
        the instance. Concurrent first calls may each query the engine; the
        last one to finish sets the cached value.
        """
        if not hasattr(self, "_supported_tasks"):
            # Cache the result
            self._supported_tasks = await self.engine_core.get_supported_tasks_async()

        return self._supported_tasks

    async def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest
        | PromptType
        | ProcessorInputs
        | AsyncGenerator[StreamingInput, None],
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        prompt_text: str | None = None,
        reasoning_ended: bool | None = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        An async-generator ``prompt`` is handled as a streaming-input
        request. A pre-built ``EngineCoreRequest`` is accepted (deprecated)
        and its own ``request_id`` wins over the ``request_id`` argument.
        Any other prompt is converted with ``InputProcessor.process_inputs``.
        For sampling requests with ``n > 1``, one child request per sample
        is submitted, all sharing the returned collector.

        Returns:
            The RequestOutputCollector that the output handler fills.

        Raises:
            EngineDeadError: if the engine has already errored.
            ValueError: if ``--kv-sharing-fast-prefill`` is enabled and the
                request asks for prompt logprobs.
            NotImplementedError: if ``reasoning_ended`` is combined with a
                streaming-input generator.
        """
        if self.errored:
            raise EngineDeadError()

        is_pooling = isinstance(params, PoolingParams)

        # prompt_logprobs=0 still requests prompt logprobs, so compare with
        # None rather than relying on truthiness.
        if (
            self.vllm_config.cache_config.kv_sharing_fast_prefill
            and not is_pooling
            and params.prompt_logprobs is not None
        ):
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs"
            )

        if isinstance(prompt, AsyncGenerator):
            if reasoning_ended is not None:
                raise NotImplementedError(
                    "reasoning_ended is not supported for streaming input requests"
                )

            # Streaming input case.
            return await self._add_streaming_input_request(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
                data_parallel_rank,
            )

        # Convert Input --> Request.
        if isinstance(prompt, EngineCoreRequest):
            logger.warning_once(
                "Passing EngineCoreRequest to AsyncLLM.generate() and .add_requests() "
                "is deprecated and will be removed in v0.18. You should instead pass "
                "the outputs of Renderer.render_cmpl() or Renderer.render_chat()."
            )

            request = prompt
            if request_id != request.request_id:
                logger.warning_once(
                    "AsyncLLM.add_request() was passed a request_id parameter that "
                    "does not match the EngineCoreRequest.request_id attribute. The "
                    "latter will be used, and the former will be ignored."
                )
        else:
            request = self.input_processor.process_inputs(
                request_id,
                prompt,
                params,
                supported_tasks=await self.get_supported_tasks(),
                arrival_time=arrival_time,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
            )
            # Derived from the prompt; replaces any caller-supplied prompt_text.
            prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)

        if reasoning_ended is not None:
            request.reasoning_ended = reasoning_ended

        self.input_processor.assign_request_id(request)

        # We start the output_handler on the first call to add_request() so
        # we can call __init__ before the event loop, which enables us
        # to handle startup failure gracefully in the OpenAI server.
        self._run_output_handler()

        # Create a new output collector for the request.
        queue = RequestOutputCollector(params.output_kind, request.request_id)

        # Use cloned params that may have been updated in process_inputs()
        params = request.params

        if is_pooling or params.n == 1:
            await self._add_request(request, prompt_text, None, 0, queue)
            return queue

        parent_params = params
        assert isinstance(parent_params, SamplingParams)

        # Fan out child requests (for n>1). The last child reuses the
        # original request object; earlier children are shallow copies.
        parent_request = ParentRequest(request)
        for idx in range(parent_params.n):
            request_id, child_params = parent_request.get_child_info(idx)
            child_request = request if idx == parent_params.n - 1 else copy(request)
            child_request.request_id = request_id
            child_request.sampling_params = child_params
            await self._add_request(
                child_request, prompt_text, parent_request, idx, queue
            )
        return queue

    async def _add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        index: int,
        queue: RequestOutputCollector,
    ):
        """Register ``request`` locally, then submit it to the EngineCore.

        Args:
            request: The processed request to submit.
            prompt: Prompt text handed to the OutputProcessor (may be None).
            parent_req: Parent of an ``n > 1`` child request, else None.
            index: Child index within ``parent_req`` (0 when there is none).
            queue: Collector that receives this request's outputs.
        """
        # Add the request to OutputProcessor (this process).
        self.output_processor.add_request(request, prompt, parent_req, index, queue)

        # Add the EngineCoreRequest to EngineCore (separate process).
        await self.engine_core.add_request_async(request)

        if self.log_requests:
            logger.info("Added request %s.", request.request_id)

    async def _add_streaming_input_request(
        self,
        request_id: str,
        input_stream: AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> RequestOutputCollector:
        """Register a request whose prompt arrives as an async stream.

        A background task consumes ``input_stream`` and submits each chunk as
        a resumable request under one internal request id. Unless that task
        is cancelled, an empty final request is submitted after the stream
        ends (normally or with an error) to signal that inputs are finished.
        Any exception raised while consuming the stream is put on the
        returned collector, wrapped in ``InputStreamError``.

        Raises:
            ValueError: if ``sampling_params`` is not supported for input
                streaming.
        """
        self._validate_streaming_input_sampling_params(sampling_params)

        # Keyword arguments shared by every process_inputs() call below.
        inputs = dict(
            supported_tasks=await self.get_supported_tasks(),
            arrival_time=arrival_time,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            trace_headers=trace_headers,
            priority=priority,
            data_parallel_rank=data_parallel_rank,
        )

        # Clone once up front (unless the caller opted out) and flag the
        # clone with skip_clone.
        if not sampling_params.skip_clone:
            sampling_params = sampling_params.clone()
            sampling_params.skip_clone = True

        # Create request for validation, also used as the finished signal
        # once the input stream is closed.
        final_req = self.input_processor.process_inputs(
            request_id=request_id,
            prompt=TokensPrompt(prompt_token_ids=[0]),
            params=sampling_params,
            **inputs,  # type: ignore[arg-type]
        )
        self.input_processor.assign_request_id(final_req)
        internal_req_id = final_req.request_id

        queue = RequestOutputCollector(sampling_params.output_kind, internal_req_id)

        async def handle_inputs():
            cancelled = False
            try:
                async for input_chunk in input_stream:
                    # Per-chunk params override the request-level ones.
                    sp = input_chunk.sampling_params
                    if sp:
                        self._validate_streaming_input_sampling_params(sp)
                    else:
                        sp = sampling_params
                    # TODO(nick): Avoid re-validating reused sampling parameters
                    req = self.input_processor.process_inputs(
                        request_id=internal_req_id,
                        prompt=input_chunk.prompt,
                        params=sp,
                        resumable=True,
                        **inputs,  # type: ignore[arg-type]
                    )
                    req.external_req_id = request_id
                    if req.prompt_embeds is not None:
                        raise ValueError(
                            "prompt_embeds not supported for streaming inputs"
                        )
                    prompt_text, _, _ = extract_prompt_components(
                        self.model_config, input_chunk.prompt
                    )
                    await self._add_request(req, prompt_text, None, 0, queue)
            except (asyncio.CancelledError, GeneratorExit):
                cancelled = True
            except Exception as error:
                # Wrap in InputStreamError so generate() can propagate it
                # without wrapping in EngineGenerateError.
                queue.put(InputStreamError(error))
            finally:
                queue._input_stream_task = None
                if not cancelled:
                    # Send empty final request to indicate that inputs have
                    # finished. Don't send if cancelled (session was aborted).
                    await self._add_request(final_req, None, None, 0, queue)

        # Ensure output handler is running.
        self._run_output_handler()

        queue._input_stream_task = asyncio.create_task(handle_inputs())
        return queue

    @staticmethod
    def _validate_streaming_input_sampling_params(
        params: SamplingParams | PoolingParams,
    ):
        """Reject parameters that input streaming does not support.

        Raises:
            ValueError: for pooling params, ``n > 1``, the ``FINAL_ONLY``
                output kind, or when stop strings are set.
        """
        if (
            not isinstance(params, SamplingParams)
            or params.n > 1
            or params.output_kind == RequestOutputKind.FINAL_ONLY
            or params.stop
        ):
            raise ValueError(
                "Input streaming not currently supported "
                "for pooling models, n > 1, request_kind = FINAL_ONLY "
                "or with stop strings."
            )

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: EngineCoreRequest
        | PromptType
        | ProcessorInputs
        | AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        reasoning_ended: bool | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Making a RequestOutputCollector queue for the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the OutputProcessor.
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request queue.

        The caller of generate() iterates the returned AsyncGenerator,
        returning the RequestOutput back to the caller.

        Raises:
            asyncio.CancelledError, GeneratorExit: re-raised after aborting
                the request (e.g. client disconnected).
            EngineDeadError: re-raised without aborting the request.
            ValueError: request validation failed.
            Exception: the original error raised by a streaming-input
                generator (unwrapped from InputStreamError).
            EngineGenerateError: any other error, chained from its cause.
        """
        q: RequestOutputCollector | None = None
        try:
            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                prompt_text=prompt_text,
                reasoning_ended=reasoning_ended,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                assert isinstance(out, RequestOutput)
                finished = out.finished
                # STREAM_FINISHED is a sentinel; it is never yielded.
                if out is not STREAM_FINISHED:
                    yield out

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError as e:
            if self.log_requests:
                logger.info("Request %s failed (bad request): %s.", request_id, e)
            raise

        # Error from input stream generator - propagate directly.
        except InputStreamError as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed (input error): %s.", request_id, e)
            raise e.cause from e

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                # str(e) can itself raise; fall back to naming both classes.
                try:
                    s = f"{e.__class__.__name__}: {e}"
                except Exception as e2:
                    s = (
                        f"{e.__class__.__name__}: "
                        "error during printing an exception of class "
                        + e2.__class__.__name__
                    )
                logger.info("Request %s failed due to %s.", request_id, s)
            raise EngineGenerateError() from e
        finally:
            if q is not None:
                q.close()

    def _run_output_handler(self):
        """Background loop: pulls from EngineCore and pushes to per-request queues.

        Starts the task at most once; later calls are no-ops. Must be called
        with a running asyncio event loop (``asyncio.create_task``). If the
        loop raises, the error is logged and handed to
        ``OutputProcessor.propagate_error`` and the task ends.
        """
        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        logger_manager = self.logger_manager
        renderer = self.renderer
        chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    num_outputs = len(outputs.outputs)

                    iteration_stats = (
                        IterationStats() if (log_stats and num_outputs) else None
                    )

                    # Split outputs into chunks of at most
                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                    # event loop for too long.
                    engine_core_outputs = outputs.outputs
                    for start in range(0, num_outputs, chunk_size):
                        end = start + chunk_size
                        outputs_slice = engine_core_outputs[start:end]
                        # 2) Process EngineCoreOutputs.
                        processed_outputs = output_processor.process_outputs(
                            outputs_slice, outputs.timestamp, iteration_stats
                        )
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed_outputs.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if end < num_outputs:
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        if processed_outputs.reqs_to_abort:
                            await engine_core.abort_requests_async(
                                processed_outputs.reqs_to_abort
                            )

                    output_processor.update_scheduler_stats(outputs.scheduler_stats)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if logger_manager:
                        logger_manager.record(
                            engine_idx=outputs.engine_index,
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                            mm_cache_stats=renderer.stat_mm_cache(),
                        )
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())

    async def abort(
        self, request_id: str | Iterable[str], internal: bool = False
    ) -> None:
        """Abort RequestId in OutputProcessor and EngineCore.

        Args:
            request_id: A single request id or an iterable of ids.
            internal: Forwarded to ``OutputProcessor.abort_requests``.

        The OutputProcessor returns the full set of ids to abort, and that
        set is sent to the EngineCore. When request logging is enabled, the
        ids as passed in are logged.
        """
        request_ids = (
            (request_id,) if isinstance(request_id, str) else as_list(request_id)
        )
        all_request_ids = self.output_processor.abort_requests(request_ids, internal)
        await self.engine_core.abort_requests_async(all_request_ids)

        if self.log_requests:
            logger.info("Aborted request(s) %s.", ",".join(request_ids))

    async def pause_generation(
        self,
        *,
        mode: PauseMode = "abort",
        wait_for_inflight_requests: bool | None = None,
        clear_cache: bool = True,
    ) -> None:
        """
        Pause generation to allow model weight updates.

        All mode handling (abort / wait / keep) and cache clearing is done
        in the engine. New generation/encoding requests will not be scheduled
        until resume is called.

        Args:
            mode: How to handle in-flight requests:
                - ``"abort"``: Abort all in-flight requests immediately
                  (default).
                - ``"wait"``: Wait for in-flight requests to complete.
                - ``"keep"``: Freeze requests in queue; they resume on
                  :meth:`resume_generation`.
            wait_for_inflight_requests: DEPRECATED: use mode argument.
                Passing any non-None value emits a DeprecationWarning;
                ``True`` overrides ``mode`` with ``"wait"``.
            clear_cache: Whether to clear KV cache and prefix cache after
                draining. Set to ``False`` to preserve cache for faster resume.
        """
        # Warn on any explicit use of the deprecated flag (including False),
        # but only let True change the requested mode.
        if wait_for_inflight_requests is not None:
            warnings.warn(
                "The `wait_for_inflight_requests` parameter in "
                "`AsyncLLM.pause_generation()` is deprecated. "
                "Please use `mode` argument instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            if wait_for_inflight_requests:
                mode = "wait"
        await self.engine_core.pause_scheduler_async(mode=mode, clear_cache=clear_cache)
        # Small sleep to help ensure that final outputs from any in-flight requests are
        # returned prior to this method returning. These outputs come out of the engine
        # prior to the wait-for-idle completion event, but involve additional async
        # tasks in output processing.
        # Note that this is not required for correctness, just more intuitive ordering
        # of events from caller's pov.
        await asyncio.sleep(0.02)

    async def resume_generation(self) -> None:
        """Let the scheduler run again after :meth:`pause_generation`."""
        engine_core = self.engine_core
        await engine_core.resume_scheduler_async()

    async def is_paused(self) -> bool:
        """Ask the engine core whether its scheduler is paused."""
        paused = await self.engine_core.is_scheduler_paused_async()
        return paused

    async def encode(
        self,
        prompt: PromptType | ProcessorInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
        reasoning_ended: bool | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Main pooling entry point called by the API server; the pooling
        counterpart of ``generate()``. It:
            * 1) Registers the request via ``add_request``, which returns the
              per-request ``RequestOutputCollector``.
            * 2) Processes the input.
            * 3) Adds the request to the EngineCore (separate process).
            (Steps 2 and 3 happen inside ``add_request``.)

        A separate output_handler loop runs in a background asyncio task,
        pulling outputs from EngineCore and putting them into the
        per-request collector.

        The caller of encode() iterates the returned AsyncGenerator, which
        yields ``PoolingRequestOutput`` objects until one has ``finished`` set.

        Raises:
            asyncio.CancelledError: The caller cancelled (e.g. client
                disconnect); the request is aborted before re-raising.
            EngineDeadError: The engine is dead; no abort is attempted.
            ValueError: Request validation error; re-raised unchanged.
            EngineGenerateError: Wraps any other exception; the request is
                aborted first.
        """

        q: RequestOutputCollector | None = None

        try:
            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                reasoning_ended=reasoning_ended,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, encode()
        # is cancelled. So, we abort the request if we end up here.
        except asyncio.CancelledError:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

        finally:
            # Always release the per-request collector, on success or error.
            if q is not None:
                q.close()

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """The renderer's tokenizer attribute (may be ``None``)."""
        renderer = self.renderer
        return renderer.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer obtained from the renderer's ``get_tokenizer``."""
        renderer = self.renderer
        return renderer.get_tokenizer()

    async def is_tracing_enabled(self) -> bool:
        """True when an OTLP traces endpoint is configured."""
        endpoint = self.observability_config.otlp_traces_endpoint
        return endpoint is not None

    async def do_log_stats(self) -> None:
        """Flush stats through the logger manager when one is configured."""
        manager = self.logger_manager
        if not manager:
            return
        manager.log()

    async def check_health(self) -> None:
        """Raise ``self.dead_error`` if the client has errored; else return."""
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error

    async def start_profile(self, profile_prefix: str | None = None) -> None:
        coros = [self.engine_core.profile_async(True, profile_prefix)]
876
877
878
        if self.profiler is not None:
            coros.append(asyncio.to_thread(self.profiler.start))
        await asyncio.gather(*coros)
879
880

    async def stop_profile(self) -> None:
        """Stop profiling in the engine core and, if set, the local profiler.

        The local profiler's blocking ``stop`` runs in a worker thread.
        """
        pending = [self.engine_core.profile_async(False)]
        local_profiler = self.profiler
        if local_profiler is not None:
            pending.append(asyncio.to_thread(local_profiler.stop))
        await asyncio.gather(*pending)

    async def reset_mm_cache(self) -> None:
        """Clear the renderer's multimodal cache, then the engine core's."""
        renderer, engine_core = self.renderer, self.engine_core
        renderer.clear_mm_cache()
        await engine_core.reset_mm_cache_async()

    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Forward a prefix-cache reset to the engine core.

        Returns:
            The bool reported by the engine core's ``reset_prefix_cache_async``.
        """
        outcome = await self.engine_core.reset_prefix_cache_async(
            reset_running_requests, reset_connector
        )
        return outcome

    async def reset_encoder_cache(self) -> None:
        """Ask the engine core to reset its encoder cache."""
        engine_core = self.engine_core
        await engine_core.reset_encoder_cache_async()

    async def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
        """Put the engine core to sleep and record ``(1, level)`` as the
        sleep state with the stat logger manager, if one exists."""
        await self.engine_core.sleep_async(level, mode)

        manager = self.logger_manager
        if manager is None:
            return
        manager.record_sleep_state(1, level)

    async def wake_up(self, tags: list[str] | None = None) -> None:
907
        await self.engine_core.wake_up_async(tags)
908

909
910
911
        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

912
913
914
    async def is_sleeping(self) -> bool:
        """Ask the engine core whether it is currently sleeping."""
        sleeping = await self.engine_core.is_sleeping_async()
        return sleeping

    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        added = await self.engine_core.add_lora_async(lora_request)
        return added

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter by id."""
        removed = await self.engine_core.remove_lora_async(lora_id)
        return removed

    async def list_loras(self) -> set[int]:
        """List all registered adapters."""
        registered = await self.engine_core.list_loras_async()
        return registered

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        pinned = await self.engine_core.pin_lora_async(lora_id)
        return pinned

    async def collective_rpc(
        self,
        method: str,
934
        timeout: float | None = None,
935
        args: tuple = (),
936
        kwargs: dict | None = None,
937
    ):
938
939
940
941
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
942
943
            method, timeout, args, kwargs
        )
944

945
946
947
948
949
950
951
952
    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """Wait for all requests to be drained.

        Polls ``engine_core.dp_engines_running()`` once per second.

        Args:
            drain_timeout: Maximum number of seconds to wait.

        Raises:
            TimeoutError: If the engines are still running after
                ``drain_timeout`` seconds.
        """
        # Measure elapsed time with a monotonic clock: wall-clock time
        # (time.time) can jump backwards or forwards, which would shorten or
        # extend the timeout unpredictably.
        start_time = time.monotonic()
        while time.monotonic() - start_time < drain_timeout:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

            logger.info("Engines are still running, waiting for requests to drain...")
            await asyncio.sleep(1)  # Wait 1 second before checking again

        raise TimeoutError(
            f"Timeout reached after {drain_timeout} seconds "
            "waiting for requests to drain."
        )

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ):
        """
        Scale up or down the data parallel size by adding or removing
        engine cores.

        Waits for in-flight requests to drain, asks the engine core to scale,
        then updates ``vllm_config.parallel_config.data_parallel_size``.

        Args:
            new_data_parallel_size: The new number of data parallel workers
            drain_timeout:
                Maximum time to wait for requests to drain (seconds)

        Raises:
            TimeoutError: Propagated from :meth:`wait_for_requests_to_drain`
                if requests do not drain in time.
        """
        old_data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if old_data_parallel_size == new_data_parallel_size:
            logger.info(
                "Data parallel size is already %s, skipping scale",
                new_data_parallel_size,
            )
            return

        # Report the real direction; this path handles both scale-up and
        # scale-down, so a hard-coded "up" would be misleading.
        direction = "up" if new_data_parallel_size > old_data_parallel_size else "down"
        logger.info(
            "Waiting for requests to drain before scaling %s to %s engines...",
            direction,
            new_data_parallel_size,
        )
        await self.wait_for_requests_to_drain(drain_timeout)
        logger.info(
            "Requests have been drained, proceeding with scale to %s engines",
            new_data_parallel_size,
        )
        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size

        # recreate stat loggers
        if new_data_parallel_size > old_data_parallel_size and self.log_stats:
            # TODO(rob): fix this after talking with Ray team.
            # This resets all the prometheus metrics since we
            # unregister during initialization. Need to understand
            # the intended behavior here better.
            self.logger_manager = StatLoggerManager(
                vllm_config=self.vllm_config,
                engine_idxs=list(range(new_data_parallel_size)),
                custom_stat_loggers=None,
            )

    @property
    def is_running(self) -> bool:
        """True until the background output handler task has finished.

        Before the loop is started the handler is ``None``, which also counts
        as running.
        """
        handler = self.output_handler
        if handler is None:
            return True
        return not handler.done()

    @property
    def is_stopped(self) -> bool:
        """Same value as :attr:`errored`."""
        stopped = self.errored
        return stopped

    @property
    def errored(self) -> bool:
        """Truthy when the engine core is dead or the output handler stopped."""
        engine_dead = self.engine_core.resources.engine_dead
        if engine_dead:
            return engine_dead
        return not self.is_running

    @property
    def dead_error(self) -> BaseException:
        """A new ``EngineDeadError`` instance, created on every access."""
        error = EngineDeadError()
        return error

    async def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest
    ) -> None:
        """
        Initialize weight transfer for RL training.

        Forwards ``request.init_info`` to the workers through
        :meth:`collective_rpc` (method ``"init_weight_transfer_engine"``).

        Args:
            request: Weight transfer initialization request with backend-specific info

        Raises:
            TypeError: If ``request`` is not a ``WeightTransferInitRequest``.
        """
        # WeightTransferInitRequest comes from the module-level import; no
        # function-local re-import is needed (matches update_weights).
        if not isinstance(request, WeightTransferInitRequest):
            raise TypeError(f"Expected WeightTransferInitRequest, got {type(request)}")

        await self.collective_rpc(
            "init_weight_transfer_engine", kwargs={"init_info": request.init_info}
        )

    async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
        """
        Batched weight update for RL training.

        Forwards ``request.update_info`` through :meth:`collective_rpc`
        (method ``"update_weights"``).

        Args:
            request: Weight update request with backend-specific update info

        Raises:
            TypeError: If ``request`` is not a ``WeightTransferUpdateRequest``.
        """
        if not isinstance(request, WeightTransferUpdateRequest):
            raise TypeError(
                f"Expected WeightTransferUpdateRequest, got {type(request)}"
            )

        payload = {"update_info": request.update_info}
        await self.collective_rpc("update_weights", kwargs=payload)