async_llm.py 41.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import os
import socket
6
import time
7
import warnings
8
from collections.abc import AsyncGenerator, Iterable, Mapping
9
from copy import copy
10
from typing import Any
11

12
import torch
13

14
import vllm.envs as envs
15
from vllm import TokensPrompt
16
from vllm.config import VllmConfig
17
18
19
20
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
21
from vllm.engine.arg_utils import AsyncEngineArgs
22
from vllm.engine.protocol import EngineClient, StreamingInput
23
from vllm.entrypoints.serve.elastic_ep.middleware import set_scaling_elastic_ep
24
from vllm.inputs import EngineInput, PromptType
25
26
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
27
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
28
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
29
from vllm.plugins.io_processors import get_io_processor
30
from vllm.pooling_params import PoolingParams
31
from vllm.renderers import renderer_from_config
32
from vllm.renderers.inputs.preprocess import extract_prompt_components
33
from vllm.sampling_params import RequestOutputKind, SamplingParams
34
from vllm.tasks import SupportedTask
35
from vllm.tokenizers import TokenizerLike
36
from vllm.tracing import init_tracer
37
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
38
from vllm.usage.usage_lib import UsageContext
39
40
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
41
from vllm.v1.engine import EngineCoreRequest, PauseMode
42
from vllm.v1.engine.core_client import EngineCoreClient
43
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
44
from vllm.v1.engine.input_processor import InputProcessor
45
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
46
from vllm.v1.engine.parallel_sampling import ParentRequest
47
from vllm.v1.executor import Executor
48
49
50
51
52
from vllm.v1.metrics.loggers import (
    StatLoggerFactory,
    StatLoggerManager,
    load_stat_logger_plugin_factories,
)
53
from vllm.v1.metrics.prometheus import shutdown_prometheus
54
from vllm.v1.metrics.stats import IterationStats
55
56
57
58

logger = init_logger(__name__)


59
60
61
62
63
64
65
66
67
68
69
70
class InputStreamError(Exception):
    """Carrier for an exception raised by the caller's input stream generator.

    The wrapped exception is kept on ``cause`` so that it can be re-raised
    as-is instead of being wrapped in EngineGenerateError.
    """

    def __init__(self, cause: Exception):
        # The wrapper's message mirrors the wrapped exception's message.
        message = str(cause)
        super().__init__(message)
        self.cause = cause


71
class AsyncLLM(EngineClient):
72
73
    """An asynchronous wrapper for the vLLM engine."""

74
75
76
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: list[StatLoggerFactory] | None = None,
        aggregate_engine_logging: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Build an AsyncLLM together with the components it drives.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats. Stats logging is also enabled
                when any custom stat logger factory is present (passed in or
                returned by ``load_stat_logger_plugin_factories()``), but in
                that case the default loggers stay disabled.
            usage_context: Usage context of the LLM. Not read by this
                constructor body.
            mm_registry: Multi-modal registry. Not read by this constructor
                body.
            use_cached_outputs: Whether to use cached outputs. Not read by
                this constructor body.
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop. Not read by
                this constructor body.
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            aggregate_engine_logging: Forwarded to StatLoggerManager.
            client_addresses: Forwarded to
                ``EngineCoreClient.make_async_mp_client``.
            client_count: Forwarded to the engine core client and to
                StatLoggerManager; also stored on the instance.
            client_index: Forwarded to
                ``EngineCoreClient.make_async_mp_client``.

        Returns:
            None
        """
        # Make sure custom transformer configs can be serialized.
        maybe_register_config_serialize_by_value()

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.observability_config = vllm_config.observability_config

        # Tracing is on exactly when an OTLP endpoint is configured.
        otlp_endpoint = self.observability_config.otlp_traces_endpoint
        tracing_enabled = otlp_endpoint is not None
        if tracing_enabled:
            init_tracer("vllm.llm_engine", otlp_endpoint)

        self.log_requests = log_requests

        # Explicitly passed factories come first, then plugin-provided ones.
        logger_factories = [
            *(stat_loggers or []),
            *load_stat_logger_plugin_factories(),
        ]
        any_custom_logger = bool(logger_factories)
        self.log_stats = log_stats or any_custom_logger
        if any_custom_logger and not log_stats:
            logger.info(
                "AsyncLLM created with log_stats=False, "
                "but custom stat loggers were found; "
                "enabling logging without default stat loggers."
            )

        renderer = renderer_from_config(self.vllm_config)
        self.renderer = renderer
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.renderer,
            self.model_config.io_processor_plugin,
        )

        # EngineInput --> EngineCoreRequest.
        self.input_processor = InputProcessor(self.vllm_config, renderer)

        # EngineCoreOutputs --> RequestOutput.
        scheduler_config = self.vllm_config.scheduler_config
        self.output_processor = OutputProcessor(
            renderer.tokenizer,
            log_stats=self.log_stats,
            stream_interval=scheduler_config.stream_interval,
            tracing_enabled=tracing_enabled,
        )

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Stat loggers; default loggers follow the caller's original
        # ``log_stats`` flag, not the possibly-upgraded ``self.log_stats``.
        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            stat_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=logger_factories,
                enable_default_loggers=log_stats,
                client_count=client_count,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager = stat_manager
            stat_manager.log_engine_initialized()

        self._client_count = client_count

        self.output_handler: asyncio.Task | None = None
        try:
            # Only start the output handler now if an event loop is already
            # running; get_running_loop() raises RuntimeError otherwise.
            asyncio.get_running_loop()
            self._run_output_handler()
        except RuntimeError:
            pass

        # Frontend (CPU-only) torch profiler, unless disabled by config.
        profiler_config = vllm_config.profiler_config
        if profiler_config.profiler != "torch" or profiler_config.ignore_frontend:
            self.profiler = None
        else:
            trace_dir = vllm_config.profiler_config.torch_profiler_dir
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                trace_dir,
            )
            trace_worker = f"{socket.gethostname()}_{os.getpid()}.async_llm"
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=vllm_config.profiler_config.torch_profiler_with_stack,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    trace_dir,
                    worker_name=trace_worker,
                    use_gzip=vllm_config.profiler_config.torch_profiler_use_gzip,
                ),
            )

210
211
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_log_requests: bool = False,
        aggregate_engine_logging: bool = False,
        disable_log_stats: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from an already-built VllmConfig.

        The executor class is resolved with ``Executor.get_class`` and the
        ``enable_log_requests`` / ``disable_log_stats`` flags are translated
        to the constructor's ``log_requests`` / ``log_stats`` arguments. All
        other arguments are passed through unchanged.
        """
        executor_class = Executor.get_class(vllm_config)
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=not disable_log_stats,
            log_requests=enable_log_requests,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            aggregate_engine_logging=aggregate_engine_logging,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

239
240
241
242
243
244
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs.

        The engine config is produced by
        ``engine_args.create_engine_config(usage_context)``; request and stats
        logging follow ``engine_args.enable_log_requests`` and
        ``engine_args.disable_log_stats``.
        """
        # Engine config first; the executor class is derived from it.
        vllm_config = engine_args.create_engine_config(usage_context)

        return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            log_stats=not engine_args.disable_log_stats,
            log_requests=engine_args.enable_log_requests,
            stat_loggers=stat_loggers,
            usage_context=usage_context,
            start_engine_loop=start_engine_loop,
        )

264
265
266
    def __del__(self):
        """Finalizer: delegate to shutdown() with its default timeout."""
        self.shutdown()

267
    def shutdown(self, timeout: float | None = None) -> None:
        """Shutdown, cleaning up the background proc and IPC.

        Every attribute is looked up with ``getattr(..., None)`` so this can
        run on an instance whose ``__init__`` did not get as far as creating
        it.

        Args:
            timeout: Forwarded to ``engine_core.shutdown``.
        """
        shutdown_prometheus()

        renderer = getattr(self, "renderer", None)
        if renderer:
            renderer.shutdown()

        engine_core = getattr(self, "engine_core", None)
        if engine_core:
            engine_core.shutdown(timeout=timeout)

        # Cancel the background output task, if one was ever started.
        if (output_task := getattr(self, "output_handler", None)) is not None:
            cancel_task_threadsafe(output_task)
280

281
    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the engine core's supported tasks, fetching them only once.

        The first call awaits ``engine_core.get_supported_tasks_async()`` and
        stores the result on the instance; later calls return the stored
        value.
        """
        if hasattr(self, "_supported_tasks"):
            return self._supported_tasks

        # Nothing cached yet: ask the engine core and remember the answer.
        fetched = await self.engine_core.get_supported_tasks_async()
        self._supported_tasks = fetched
        return self._supported_tasks
287

288
289
290
    async def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest
        | PromptType
        | EngineInput
        | AsyncGenerator[StreamingInput, None],
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        prompt_text: str | None = None,
        reasoning_ended: bool | None = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        Args:
            request_id: Caller-supplied id for the request.
            prompt: One of: an async generator of StreamingInput chunks
                (handed to the streaming-input path), a ready-made
                EngineCoreRequest (deprecated, used as-is), or any other
                prompt/engine input, which is converted by the input
                processor.
            params: Sampling or pooling parameters.
            arrival_time, lora_request, tokenization_kwargs, trace_headers,
            priority, data_parallel_rank: Forwarded to the input processor
                (or to the streaming-input path).
            prompt_text: Prompt text handed to the output processor. Only
                honoured when ``prompt`` is an EngineCoreRequest; otherwise
                it is replaced by the text extracted from ``prompt``.
            reasoning_ended: When not None, copied onto the request. Not
                supported together with a streaming prompt.

        Returns:
            The RequestOutputCollector that receives this request's outputs.

        Raises:
            EngineDeadError: If the engine has errored.
            ValueError: If prompt logprobs are requested while
                ``kv_sharing_fast_prefill`` is enabled.
            NotImplementedError: If ``reasoning_ended`` is given with a
                streaming prompt.
        """
        if self.errored:
            raise EngineDeadError()

        is_pooling = isinstance(params, PoolingParams)

        fast_prefill = self.vllm_config.cache_config.kv_sharing_fast_prefill
        if fast_prefill and not is_pooling and params.prompt_logprobs:
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs"
            )

        # Streaming input case.
        if isinstance(prompt, AsyncGenerator):
            if reasoning_ended is not None:
                raise NotImplementedError

            return await self._add_streaming_input_request(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
                data_parallel_rank,
            )

        # Convert Input --> Request.
        if not isinstance(prompt, EngineCoreRequest):
            request = self.input_processor.process_inputs(
                request_id,
                prompt,
                params,
                supported_tasks=await self.get_supported_tasks(),
                arrival_time=arrival_time,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
            )
            prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
        else:
            logger.warning_once(
                "Passing EngineCoreRequest to AsyncLLM.generate() and .add_requests() "
                "is deprecated and will be removed in v0.18. You should instead pass "
                "the outputs of Renderer.render_cmpl() or Renderer.render_chat()."
            )

            request = prompt
            if request_id != request.request_id:
                logger.warning_once(
                    "AsyncLLM.add_request() was passed a request_id parameter that "
                    "does not match the EngineCoreRequest.request_id attribute. The "
                    "latter will be used, and the former will be ignored."
                )

        if reasoning_ended is not None:
            request.reasoning_ended = reasoning_ended

        self.input_processor.assign_request_id(request)

        # We start the output_handler on the first call to add_request() so
        # we can call __init__ before the event loop, which enables us
        # to handle startup failure gracefully in the OpenAI server.
        self._run_output_handler()

        # One collector per request; it is built from the caller's params
        # before they are swapped for the request's own copy below.
        collector = RequestOutputCollector(params.output_kind, request.request_id)

        # Use cloned params that may have been updated in process_inputs()
        params = request.params

        if is_pooling or params.n == 1:
            await self._add_request(request, prompt_text, None, 0, collector)
            return collector

        parent_params = params
        assert isinstance(parent_params, SamplingParams)

        # Fan out child requests (for n>1). The last child reuses the
        # original request object; the others get a shallow copy.
        parent_request = ParentRequest(request)
        for child_idx in range(parent_params.n):
            child_id, child_params = parent_request.get_child_info(child_idx)
            if child_idx == parent_params.n - 1:
                child_request = request
            else:
                child_request = copy(request)
            child_request.request_id = child_id
            child_request.sampling_params = child_params
            await self._add_request(
                child_request, prompt_text, parent_request, child_idx, collector
            )
        return collector
404

405
406
407
    async def _add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        index: int,
        queue: RequestOutputCollector,
    ):
        """Register one request with the output processor, then the engine core.

        Args:
            request: The request to register and submit.
            prompt: Prompt text passed to the output processor, if any.
            parent_req: Parent request when this is a fanned-out child.
            index: Index passed to the output processor alongside
                ``parent_req``.
            queue: Collector passed to the output processor for this request.
        """
        # Output side first (this process) ...
        self.output_processor.add_request(request, prompt, parent_req, index, queue)

        # ... then hand the EngineCoreRequest to EngineCore (separate process).
        await self.engine_core.add_request_async(request)

        if not self.log_requests:
            return
        logger.info("Added request %s.", request.request_id)
421

422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
    async def _add_streaming_input_request(
        self,
        request_id: str,
        input_stream: AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> RequestOutputCollector:
        """Add a request whose prompt arrives incrementally from a generator.

        A placeholder request is built up front, both to validate the
        parameters and to serve as the "inputs finished" signal. A background
        task then consumes ``input_stream`` and submits each chunk as a
        resumable request under the same internal request id.

        Args:
            request_id: Caller-supplied id; recorded on each chunk request as
                ``external_req_id``.
            input_stream: Async generator yielding StreamingInput chunks.
            sampling_params: Default parameters, used for chunks that do not
                carry their own.
            arrival_time, lora_request, tokenization_kwargs, trace_headers,
            priority, data_parallel_rank: Forwarded to every
                ``process_inputs`` call.

        Returns:
            The RequestOutputCollector for the request; the consuming task is
            stored on its ``_input_stream_task`` attribute.

        Raises:
            ValueError: If ``sampling_params`` is rejected by
                ``_validate_streaming_input_sampling_params``.
        """
        self._validate_streaming_input_sampling_params(sampling_params)

        # Keyword arguments shared by every process_inputs() call below.
        shared_kwargs = {
            "supported_tasks": await self.get_supported_tasks(),
            "arrival_time": arrival_time,
            "lora_request": lora_request,
            "tokenization_kwargs": tokenization_kwargs,
            "trace_headers": trace_headers,
            "priority": priority,
            "data_parallel_rank": data_parallel_rank,
        }

        # Clone once here and mark the clone so it is not cloned again.
        if not sampling_params.skip_clone:
            sampling_params = sampling_params.clone()
            sampling_params.skip_clone = True

        # Create request for validation, also used as the finished signal
        # once the input stream is closed.
        final_req = self.input_processor.process_inputs(
            request_id=request_id,
            prompt=TokensPrompt(prompt_token_ids=[0]),
            params=sampling_params,
            **shared_kwargs,  # type: ignore[arg-type]
        )
        self.input_processor.assign_request_id(final_req)
        internal_req_id = final_req.request_id

        queue = RequestOutputCollector(sampling_params.output_kind, internal_req_id)

        async def handle_inputs():
            was_cancelled = False
            try:
                async for input_chunk in input_stream:
                    # Per-chunk params are validated; otherwise fall back to
                    # the request-level params.
                    chunk_params = input_chunk.sampling_params
                    if not chunk_params:
                        chunk_params = sampling_params
                    else:
                        self._validate_streaming_input_sampling_params(chunk_params)
                    # TODO(nick): Avoid re-validating reused sampling parameters
                    chunk_req = self.input_processor.process_inputs(
                        request_id=internal_req_id,
                        prompt=input_chunk.prompt,
                        params=chunk_params,
                        resumable=True,
                        **shared_kwargs,  # type: ignore[arg-type]
                    )
                    chunk_req.external_req_id = request_id
                    if chunk_req.prompt_embeds is not None:
                        raise ValueError(
                            "prompt_embeds not supported for streaming inputs"
                        )
                    chunk_text, _, _ = extract_prompt_components(
                        self.model_config, input_chunk.prompt
                    )
                    await self._add_request(chunk_req, chunk_text, None, 0, queue)
            except (asyncio.CancelledError, GeneratorExit):
                was_cancelled = True
            except Exception as stream_error:
                # Wrap in InputStreamError so generate() can propagate it
                # without wrapping in EngineGenerateError.
                queue.put(InputStreamError(stream_error))
            finally:
                queue._input_stream_task = None
                if not was_cancelled:
                    # Send empty final request to indicate that inputs have
                    # finished. Don't send if cancelled (session was aborted).
                    await self._add_request(final_req, None, None, 0, queue)

        # Ensure output handler is running.
        self._run_output_handler()

        queue._input_stream_task = asyncio.create_task(handle_inputs())
        return queue

    @staticmethod
    def _validate_streaming_input_sampling_params(
        params: SamplingParams | PoolingParams,
    ):
        """Reject parameter combinations that streaming input cannot handle.

        Raises:
            ValueError: If ``params`` is not a SamplingParams, asks for
                ``n > 1``, uses ``RequestOutputKind.FINAL_ONLY``, or sets
                stop strings.
        """
        unsupported = (
            not isinstance(params, SamplingParams)
            or params.n > 1
            or params.output_kind == RequestOutputKind.FINAL_ONLY
            or params.stop
        )
        if unsupported:
            raise ValueError(
                "Input streaming not currently supported "
                "for pooling models, n > 1, request_kind = FINAL_ONLY "
                "or with stop strings."
            )

524
525
526
527
528
    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
529
    async def generate(
        self,
        prompt: EngineCoreRequest
        | PromptType
        | EngineInput
        | AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        reasoning_ended: bool | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Submit a request and stream its RequestOutputs back to the caller.

        The request is registered through add_request(), which returns the
        request's RequestOutputCollector. A separate output_handler task
        pushes outputs into that collector; this generator pulls from it and
        yields each output until one reports ``finished``. The
        STREAM_FINISHED sentinel itself is never yielded.

        Error handling:
            * Cancellation / GeneratorExit: the request is aborted and the
              exception is re-raised.
            * EngineDeadError: re-raised without aborting.
            * ValueError: re-raised as-is.
            * InputStreamError: the request is aborted and the wrapped
              ``cause`` is raised.
            * Any other Exception: the request is aborted and an
              EngineGenerateError is raised from it.

        The collector, if one was created, is always closed on exit.
        """

        collector: RequestOutputCollector | None = None
        try:
            collector = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                prompt_text=prompt_text,
                reasoning_ended=reasoning_ended,
            )

            # The output_handler task pushes items into the collector.
            # This task pulls from the collector and yields to caller.
            while True:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                output = collector.get_nowait() or await collector.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                assert isinstance(output, RequestOutput)
                is_last = output.finished
                if output is not STREAM_FINISHED:
                    yield output
                if is_last:
                    break

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            if collector is not None:
                await self.abort(collector.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError as bad_request:
            if self.log_requests:
                logger.info(
                    "Request %s failed (bad request): %s.", request_id, bad_request
                )
            raise

        # Error from input stream generator - propagate directly.
        except InputStreamError as stream_err:
            if collector is not None:
                await self.abort(collector.request_id, internal=True)
            if self.log_requests:
                logger.info(
                    "Request %s failed (input error): %s.", request_id, stream_err
                )
            raise stream_err.cause from stream_err

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as err:
            if collector is not None:
                await self.abort(collector.request_id, internal=True)
            if self.log_requests:
                # Formatting the exception may itself fail; fall back to a
                # description built from class names only.
                try:
                    summary = f"{err.__class__.__name__}: {err}"
                except Exception as fmt_err:
                    summary = (
                        f"{err.__class__.__name__}: "
                        "error during printing an exception of class"
                        + fmt_err.__class__.__name__
                    )
                logger.info("Request %s failed due to %s.", request_id, summary)
            raise EngineGenerateError() from err
        finally:
            if collector is not None:
                collector.close()
639
640
641
642
643
644
645
646
647
648
649
650

    def _run_output_handler(self):
        """Start the background output task unless it already exists.

        The task loops forever: it pulls EngineCoreOutputs from the engine
        core, feeds them to the output processor in chunks, aborts requests
        the processor asks to abort, and records stats. Any exception ends
        the loop and is handed to ``output_processor.propagate_error``.
        """

        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly.
        # Everything the task needs is therefore bound to locals here.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        # We use a mutable list for logger_manager so that it can be updated
        # during elastic EP scaling (see scale_elastic_ep) without creating
        # a circular reference via self.
        logger_ref = self._logger_ref = [self.logger_manager]
        renderer = self.renderer
        chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    batch = await engine_core.get_output_async()
                    total = len(batch.outputs)

                    # Stats are only collected for non-empty batches.
                    if log_stats and total:
                        iteration_stats = IterationStats()
                    else:
                        iteration_stats = None

                    # Split outputs into chunks of at most
                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                    # event loop for too long.
                    core_outputs = batch.outputs
                    for chunk_start in range(0, total, chunk_size):
                        chunk_end = chunk_start + chunk_size

                        # 2) Process EngineCoreOutputs.
                        processed = output_processor.process_outputs(
                            core_outputs[chunk_start:chunk_end],
                            batch.timestamp,
                            iteration_stats,
                        )
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if chunk_end < total:
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        if processed.reqs_to_abort:
                            await engine_core.abort_requests_async(
                                processed.reqs_to_abort
                            )

                    output_processor.update_scheduler_stats(batch.scheduler_stats)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    stat_logger = logger_ref[0]
                    if stat_logger:
                        stat_logger.record(
                            engine_idx=batch.engine_index,
                            scheduler_stats=batch.scheduler_stats,
                            iteration_stats=iteration_stats,
                            mm_cache_stats=renderer.stat_mm_cache(),
                        )
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())
711

712
713
714
    async def abort(
        self, request_id: str | Iterable[str], internal: bool = False
    ) -> None:
715
        """Abort RequestId in OutputProcessor and EngineCore."""
716

717
718
719
        request_ids = (
            (request_id,) if isinstance(request_id, str) else as_list(request_id)
        )
720
        all_request_ids = self.output_processor.abort_requests(request_ids, internal)
721
        await self.engine_core.abort_requests_async(all_request_ids)
722

723
        if self.log_requests:
724
            logger.info("Aborted request(s) %s.", ",".join(request_ids))
725

726
727
728
    async def pause_generation(
        self,
        *,
        mode: PauseMode = "abort",
        wait_for_inflight_requests: bool | None = None,
        clear_cache: bool = True,
    ) -> None:
        """
        Pause generation to allow model weight updates.

        All mode handling (abort / wait / keep) and cache clearing is done
        in the engine. New generation/encoding requests will not be scheduled
        until resume is called.

        Args:
            mode: How to handle in-flight requests:
                - ``"abort"``: Abort all in-flight requests immediately
                  (default).
                - ``"wait"``: Wait for in-flight requests to complete.
                - ``"keep"``: Freeze requests in queue; they resume on
                  :meth:`resume_generation`.
            wait_for_inflight_requests: DEPRECATED: use mode argument.
                A truthy value emits a DeprecationWarning and forces
                ``mode="wait"``.
            clear_cache: Whether to clear KV cache and prefix cache after
                draining. Set to ``False`` to preserve cache for faster resume.
        """
        if wait_for_inflight_requests:
            deprecation_msg = (
                "The `wait_for_inflight_requests` parameter in "
                "`AsyncLLM.pause_generation()` is deprecated. "
                "Please use `mode` argument instead."
            )
            warnings.warn(deprecation_msg, DeprecationWarning, stacklevel=2)
            mode = "wait"

        await self.engine_core.pause_scheduler_async(mode=mode, clear_cache=clear_cache)

        # Small sleep to help ensure that final outputs from any in-flight requests are
        # returned prior to this method returning. These outputs come out of the engine
        # prior to the wait-for-idle completion event, but involve additional async
        # tasks in output processing.
        # Note that this is not required for correctness, just more intuitive ordering
        # of events from caller's pov.
        await asyncio.sleep(0.02)
768
769
770

    async def resume_generation(self) -> None:
        """Resume generation after :meth:`pause_generation`.

        Delegates to ``engine_core.resume_scheduler_async``.
        """
        core = self.engine_core
        await core.resume_scheduler_async()
772
773
774

    async def is_paused(self) -> bool:
        """Return whether the engine is currently paused.

        The answer comes from ``engine_core.is_scheduler_paused_async``.
        """
        paused = await self.engine_core.is_scheduler_paused_async()
        return paused
776

777
    async def encode(
        self,
        prompt: PromptType | EngineInput,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
        reasoning_ended: bool | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Submit a pooling request and stream its outputs.

            * 1) Add the request via add_request(), which returns the
                 per-request output collector.
            * 2) Pull outputs from the collector and yield each one to the
                 caller until an output marked finished has been yielded.

        The output_handler task pushes items into the collector; this
        generator only drains it.

        The caller of encode() iterates the returned AsyncGenerator,
        receiving PoolingRequestOutput objects.

        Raises:
            asyncio.CancelledError: Re-raised after aborting the request.
            GeneratorExit: Re-raised after aborting the request when the
                generator is closed before it finished.
            EngineDeadError: Re-raised as-is; no abort is attempted.
            ValueError: Re-raised as-is (request validation error).
            EngineGenerateError: Wraps any other exception, after aborting
                the request.
        """

        q: RequestOutputCollector | None = None
        try:
            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                reasoning_ended=reasoning_ended,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, encode() is
        # cancelled, or the generator is closed early (aclose() / garbage
        # collection raises GeneratorExit at the yield). GeneratorExit is a
        # BaseException, so it must be named explicitly here; otherwise the
        # request would be left unaborted. Abort if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e
        finally:
            # Always release the collector, whatever the exit path.
            if q is not None:
                q.close()
859

860
    @property
    def tokenizer(self) -> TokenizerLike | None:
        """The renderer's ``tokenizer`` attribute (may be ``None``)."""
        renderer = self.renderer
        return renderer.tokenizer
863

864
    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer obtained from ``renderer.get_tokenizer()``."""
        renderer = self.renderer
        return renderer.get_tokenizer()
866
867

    async def is_tracing_enabled(self) -> bool:
        """Return True when an OTLP traces endpoint is configured."""
        endpoint = self.observability_config.otlp_traces_endpoint
        return endpoint is not None
869

870
    async def do_log_stats(self) -> None:
        """Call ``log()`` on the logger manager, if one is set."""
        manager = self.logger_manager
        if not manager:
            # Nothing to log with.
            return
        manager.log()
873
874
875

    async def check_health(self) -> None:
        """Raise ``self.dead_error`` if the engine is in an errored state.

        Returns normally (``None``) when ``self.errored`` is false.
        """
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error
878

879
880
    async def start_profile(self, profile_prefix: str | None = None) -> None:
        coros = [self.engine_core.profile_async(True, profile_prefix)]
881
882
883
        if self.profiler is not None:
            coros.append(asyncio.to_thread(self.profiler.start))
        await asyncio.gather(*coros)
884
885

    async def stop_profile(self) -> None:
        """Stop profiling in the engine core and in the local profiler, if any."""
        pending = [self.engine_core.profile_async(False)]
        local_profiler = self.profiler
        if local_profiler is not None:
            # The local profiler's stop() is run via asyncio.to_thread.
            pending.append(asyncio.to_thread(local_profiler.stop))
        await asyncio.gather(*pending)
890

891
    async def reset_mm_cache(self) -> None:
        """Clear the renderer's mm cache, then reset the engine core's."""
        renderer = self.renderer
        renderer.clear_mm_cache()
        core = self.engine_core
        await core.reset_mm_cache_async()

895
896
897
898
899
900
    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the prefix cache through the engine core.

        Both flags are forwarded positionally to
        ``engine_core.reset_prefix_cache_async`` and its result is returned.
        """
        outcome = await self.engine_core.reset_prefix_cache_async(
            reset_running_requests, reset_connector
        )
        return outcome
901

902
903
904
    async def reset_encoder_cache(self) -> None:
        """Reset the encoder cache through the engine core."""
        core = self.engine_core
        await core.reset_encoder_cache_async()

905
906
    async def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
        """Put the engine core to sleep and record the new sleep state.

        Args:
            level: Forwarded to ``engine_core.sleep_async``; also recorded
                on the logger manager.
            mode: Forwarded to ``engine_core.sleep_async``.
        """
        await self.engine_core.sleep_async(level, mode)

        stats = self.logger_manager
        if stats is None:
            return
        stats.record_sleep_state(1, level)

911
    async def wake_up(self, tags: list[str] | None = None) -> None:
912
        await self.engine_core.wake_up_async(tags)
913

914
915
916
        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

917
918
919
    async def is_sleeping(self) -> bool:
        """Return the engine core's answer to ``is_sleeping_async()``."""
        asleep = await self.engine_core.is_sleeping_async()
        return asleep

920
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests.

        Returns whatever ``engine_core.add_lora_async`` returns.
        """
        added = await self.engine_core.add_lora_async(lora_request)
        return added

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter.

        Returns whatever ``engine_core.remove_lora_async`` returns.
        """
        removed = await self.engine_core.remove_lora_async(lora_id)
        return removed

928
    async def list_loras(self) -> set[int]:
        """List all registered adapters.

        Returns whatever ``engine_core.list_loras_async`` returns.
        """
        adapter_ids = await self.engine_core.list_loras_async()
        return adapter_ids

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted.

        Returns whatever ``engine_core.pin_lora_async`` returns.
        """
        pinned = await self.engine_core.pin_lora_async(lora_id)
        return pinned
935

936
937
938
    async def collective_rpc(
        self,
        method: str,
939
        timeout: float | None = None,
940
        args: tuple = (),
941
        kwargs: dict | None = None,
942
    ):
943
944
945
946
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
947
948
            method, timeout, args, kwargs
        )
949

950
951
952
953
954
955
956
957
    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """Wait for all requests to be drained.

        Polls ``engine_core.dp_engines_running()`` once per second and
        returns as soon as it reports that no engine is running.

        Args:
            drain_timeout: Maximum time to wait, in seconds.

        Raises:
            TimeoutError: If the engines are still running once
                ``drain_timeout`` seconds have elapsed.
        """
        # A monotonic clock is used for the deadline so that wall-clock
        # adjustments can neither shorten nor extend the timeout.
        deadline = time.monotonic() + drain_timeout
        while time.monotonic() < deadline:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

            logger.info("Engines are still running, waiting for requests to drain...")
            await asyncio.sleep(1)  # Wait 1 second before checking again

        raise TimeoutError(
            f"Timeout reached after {drain_timeout} seconds "
            "waiting for requests to drain."
        )
965

966
967
968
    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ):
        """
        Scale the data parallel size up or down by adding or removing
        engine cores.

        Only logs and returns when the configured size already equals the
        requested one. Otherwise it optionally waits for requests to drain,
        rebuilds the stat logger manager when scaling up with stats enabled,
        and then asks the engine core to scale.

        Args:
            new_data_parallel_size: The new number of data parallel workers
            drain_timeout:
                Maximum time to wait for requests to drain (seconds)
        """
        current_dp_size = self.vllm_config.parallel_config.data_parallel_size
        if current_dp_size == new_data_parallel_size:
            logger.info(
                "Data parallel size is already %s, skipping scale",
                new_data_parallel_size,
            )
            return

        if envs.VLLM_ELASTIC_EP_DRAIN_REQUESTS:
            logger.info(
                "VLLM_ELASTIC_EP_DRAIN_REQUESTS is set, "
                "waiting for requests to drain before scaling"
            )
            await self.wait_for_requests_to_drain(drain_timeout)

        # Stat loggers are only recreated when growing and stats are on.
        scaling_up = new_data_parallel_size > current_dp_size
        if scaling_up and self.log_stats:
            # TODO(rob): fix this after talking with Ray team.
            # This resets all the prometheus metrics since we
            # unregister during initialization. Need to understand
            # the intended behavior here better.
            rebuilt_manager = StatLoggerManager(
                vllm_config=self.vllm_config,
                engine_idxs=list(range(new_data_parallel_size)),
                custom_stat_loggers=None,
            )
            self.logger_manager = rebuilt_manager
            # Refresh the mutable ref so output_handler sees the new
            # logger without holding a circular reference through self.
            if hasattr(self, "_logger_ref"):
                self._logger_ref[0] = rebuilt_manager
            rebuilt_manager.log_engine_initialized()

        # The scaling flag is set for the duration of the engine-core call
        # and always cleared afterwards, even on failure.
        set_scaling_elastic_ep(True)
        try:
            await self.engine_core.scale_elastic_ep(new_data_parallel_size)
            # Only record the new size once the engine core call succeeded.
            self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
        finally:
            set_scaling_elastic_ep(False)
1015

1016
1017
    @property
    def is_running(self) -> bool:
        """True unless the output handler task exists and is done."""
        handler = self.output_handler
        if handler is None:
            # Is None before the loop is started.
            return True
        return not handler.done()
1020
1021
1022

    @property
    def is_stopped(self) -> bool:
        """Same value as :attr:`errored`."""
        has_errored = self.errored
        return has_errored
1024
1025
1026

    @property
    def errored(self) -> bool:
        """True when the engine core is flagged dead or we are not running."""
        engine_dead = self.engine_core.resources.engine_dead
        return engine_dead or not self.is_running
1028
1029
1030

    @property
    def dead_error(self) -> BaseException:
        """A freshly constructed ``EngineDeadError`` instance."""
        error = EngineDeadError()
        return error
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072

    async def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest
    ) -> None:
        """
        Initialize weight transfer for RL training.

        The request's ``init_info`` is sent to the workers through
        :meth:`collective_rpc` as the ``init_info`` keyword argument of
        ``init_weight_transfer_engine``.

        Args:
            request: Weight transfer initialization request with backend-specific info

        Raises:
            TypeError: If ``request`` is not a ``WeightTransferInitRequest``.
        """
        if not isinstance(request, WeightTransferInitRequest):
            raise TypeError(f"Expected WeightTransferInitRequest, got {type(request)}")

        await self.collective_rpc(
            "init_weight_transfer_engine", kwargs={"init_info": request.init_info}
        )

    async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
        """
        Batched weight update for RL training.

        The request's ``update_info`` is sent to the workers through
        :meth:`collective_rpc` as the ``update_info`` keyword argument of
        ``update_weights``.

        Args:
            request: Weight update request with backend-specific update info

        Raises:
            TypeError: If ``request`` is not a ``WeightTransferUpdateRequest``.
        """
        # Reject anything that is not the expected request type up front.
        if not isinstance(request, WeightTransferUpdateRequest):
            raise TypeError(
                f"Expected WeightTransferUpdateRequest, got {type(request)}"
            )

        payload = request.update_info
        await self.collective_rpc("update_weights", kwargs={"update_info": payload})