async_llm.py 41.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import os
import socket
6
import time
7
import warnings
8
from collections.abc import AsyncGenerator, Iterable, Mapping
9
from copy import copy
10
from typing import Any
11

12
import torch
13

14
import vllm.envs as envs
15
from vllm import TokensPrompt
16
from vllm.config import VllmConfig
17
18
19
20
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
21
from vllm.engine.arg_utils import AsyncEngineArgs
22
from vllm.engine.protocol import EngineClient, StreamingInput
23
from vllm.entrypoints.serve.elastic_ep.middleware import set_scaling_elastic_ep
24
from vllm.inputs import EngineInput, PromptType
25
26
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
27
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
28
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
29
from vllm.pooling_params import PoolingParams
30
from vllm.renderers import renderer_from_config
31
from vllm.renderers.inputs.preprocess import extract_prompt_components
32
from vllm.sampling_params import RequestOutputKind, SamplingParams
33
from vllm.tasks import SupportedTask
34
from vllm.tokenizers import TokenizerLike
35
from vllm.tracing import init_tracer
36
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
37
from vllm.usage.usage_lib import UsageContext
38
39
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
40
from vllm.v1.engine import EngineCoreRequest, PauseMode
41
from vllm.v1.engine.core_client import EngineCoreClient
42
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
43
from vllm.v1.engine.input_processor import InputProcessor
44
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
45
from vllm.v1.engine.parallel_sampling import ParentRequest
46
from vllm.v1.executor import Executor
47
48
49
50
51
from vllm.v1.metrics.loggers import (
    StatLoggerFactory,
    StatLoggerManager,
    load_stat_logger_plugin_factories,
)
52
from vllm.v1.metrics.prometheus import shutdown_prometheus
53
from vllm.v1.metrics.stats import IterationStats
54
55
56
57

# Module-level logger for AsyncLLM.
logger = init_logger(__name__)


class InputStreamError(Exception):
    """Carry an exception raised by a user-supplied input stream.

    ``generate()`` catches this wrapper and re-raises the original exception
    (kept in ``cause``) instead of converting it into an
    ``EngineGenerateError``.

    Args:
        cause: The exception raised while consuming the input stream. Its
            ``str()`` becomes this exception's message.
    """

    def __init__(self, cause: Exception):
        super().__init__(str(cause))
        self.cause = cause


class AsyncLLM(EngineClient):
71
72
    """An asynchronous wrapper for the vLLM engine."""

73
74
75
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: list[StatLoggerFactory] | None = None,
        aggregate_engine_logging: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats. Stats collection is also turned
                on (without the default stat loggers) when custom stat
                loggers are supplied or discovered via plugins.
            usage_context: Usage context of the LLM. Not read by this
                constructor.
            mm_registry: Multi-modal registry. Not read by this constructor.
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop. Not read by
                this constructor: the output handler is started eagerly when
                an event loop is already running, otherwise lazily when the
                first request is added.
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            aggregate_engine_logging: Forwarded to ``StatLoggerManager``.
            client_addresses: Forwarded to the engine-core client.
            client_count: Forwarded to the engine-core client and to
                ``StatLoggerManager``.
            client_index: Forwarded to the engine-core client.

        Returns:
            None
        """
        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.observability_config = vllm_config.observability_config

        tracing_endpoint = self.observability_config.otlp_traces_endpoint
        if tracing_endpoint is not None:
            init_tracer("vllm.llm_engine", tracing_endpoint)

        self.log_requests = log_requests

        # Explicit stat loggers plus any registered via plugins.
        custom_stat_loggers = list(stat_loggers or [])
        custom_stat_loggers.extend(load_stat_logger_plugin_factories())

        has_custom_loggers = bool(custom_stat_loggers)
        self.log_stats = log_stats or has_custom_loggers
        if not log_stats and has_custom_loggers:
            logger.info(
                "AsyncLLM created with log_stats=False, "
                "but custom stat loggers were found; "
                "enabling logging without default stat loggers."
            )

        self.renderer = renderer = renderer_from_config(self.vllm_config)

        # Convert EngineInput --> EngineCoreRequest.
        self.input_processor = InputProcessor(self.vllm_config, renderer)

        # Converts EngineCoreOutputs --> RequestOutput.
        self.output_processor = OutputProcessor(
            renderer.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
            tracing_enabled=tracing_endpoint is not None,
        )

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Loggers. Default loggers follow the caller's log_stats flag; custom
        # loggers alone enable the manager without the defaults.
        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=custom_stat_loggers,
                enable_default_loggers=log_stats,
                client_count=client_count,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        self._client_count = client_count

        self.output_handler: asyncio.Task | None = None
        # Start the output handler eagerly if we are already inside a running
        # asyncio event loop; otherwise it is started lazily when a request
        # is added. Only the loop probe is guarded, so a RuntimeError raised
        # while starting the handler itself is not silently swallowed.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass
        else:
            self._run_output_handler()

        if (
            vllm_config.profiler_config.profiler == "torch"
            and not vllm_config.profiler_config.ignore_frontend
        ):
            profiler_dir = vllm_config.profiler_config.torch_profiler_dir
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                profiler_dir,
            )
            # Per-process trace name so multiple frontends don't collide.
            worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=vllm_config.profiler_config.torch_profiler_with_stack,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    profiler_dir,
                    worker_name=worker_name,
                    use_gzip=vllm_config.profiler_config.torch_profiler_use_gzip,
                ),
            )
        else:
            self.profiler = None

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_log_requests: bool = False,
        aggregate_engine_logging: bool = False,
        disable_log_stats: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from an already-built VllmConfig.

        The executor class is resolved from ``vllm_config`` with
        ``Executor.get_class``. ``enable_log_requests`` maps to the
        constructor's ``log_requests`` and ``disable_log_stats`` is inverted
        into ``log_stats``; every other argument is forwarded unchanged.
        """
        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            log_requests=enable_log_requests,
            log_stats=not disable_log_stats,
            aggregate_engine_logging=aggregate_engine_logging,
            usage_context=usage_context,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs.

        Builds the ``VllmConfig`` with ``engine_args.create_engine_config``,
        resolves the executor class from it, and maps the args'
        ``enable_log_requests`` / ``disable_log_stats`` onto the
        constructor's ``log_requests`` / ``log_stats`` flags.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    def __del__(self):
        # Best-effort cleanup on garbage collection. shutdown() looks up the
        # renderer, engine core and output handler with getattr() defaults,
        # so it tolerates an instance whose __init__ did not complete.
        self.shutdown()

    def shutdown(self, timeout: float | None = None) -> None:
        """Shutdown, cleaning up the background proc and IPC.

        Calls ``shutdown_prometheus()``, shuts down the renderer and the
        engine-core client, and cancels the output-handler task. Each
        component is looked up with ``getattr`` and skipped if absent, so
        this can run on a partially constructed instance (e.g. from
        ``__del__``).

        Args:
            timeout: Forwarded to the engine-core client's ``shutdown``.
        """
        shutdown_prometheus()

        if renderer := getattr(self, "renderer", None):
            renderer.shutdown()

        if engine_core := getattr(self, "engine_core", None):
            engine_core.shutdown(timeout=timeout)

        handler = getattr(self, "output_handler", None)
        if handler is not None:
            cancel_task_threadsafe(handler)

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the engine.

        The value is fetched from the engine core on the first call and
        cached on the instance for every later call.
        """
        if not hasattr(self, "_supported_tasks"):
            # Cache the result
            self._supported_tasks = await self.engine_core.get_supported_tasks_async()

        return self._supported_tasks

    async def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest
        | PromptType
        | EngineInput
        | AsyncGenerator[StreamingInput, None],
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        prompt_text: str | None = None,
        reasoning_ended: bool | None = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        Args:
            request_id: Caller-supplied request id. If ``prompt`` is an
                ``EngineCoreRequest`` with a different ``request_id``, the
                request's own id wins and a warning is logged.
            prompt: The prompt. An ``AsyncGenerator`` selects the
                streaming-input path; passing a pre-built
                ``EngineCoreRequest`` is deprecated.
            params: Sampling or pooling parameters. For sampling with
                ``n > 1`` the request is fanned out into ``n`` child
                requests that share one output collector.
            arrival_time: Forwarded to input processing.
            lora_request: Forwarded to input processing.
            tokenization_kwargs: Forwarded to input processing.
            trace_headers: Forwarded to input processing.
            priority: Forwarded to input processing.
            data_parallel_rank: Forwarded to input processing.
            prompt_text: Prompt text handed to output processing. Only used
                when ``prompt`` is an ``EngineCoreRequest``; otherwise it is
                re-derived from ``prompt``.
            reasoning_ended: If not ``None``, copied onto the request.
                Not supported for streaming inputs.

        Returns:
            The ``RequestOutputCollector`` that receives this request's
            outputs.

        Raises:
            EngineDeadError: If the engine has already errored.
            ValueError: If ``kv_sharing_fast_prefill`` is enabled and prompt
                logprobs are requested.
            NotImplementedError: If ``reasoning_ended`` is given together
                with a streaming input.
        """
        if self.errored:
            raise EngineDeadError()

        is_pooling = isinstance(params, PoolingParams)

        if (
            self.vllm_config.cache_config.kv_sharing_fast_prefill
            and not is_pooling
            and params.prompt_logprobs
        ):
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs"
            )

        if isinstance(prompt, AsyncGenerator):
            if reasoning_ended is not None:
                raise NotImplementedError(
                    "reasoning_ended is not supported for streaming inputs"
                )

            # Streaming input case.
            return await self._add_streaming_input_request(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
                data_parallel_rank,
            )

        # Convert Input --> Request.
        if isinstance(prompt, EngineCoreRequest):
            logger.warning_once(
                "Passing EngineCoreRequest to AsyncLLM.generate() and .add_requests() "
                "is deprecated and will be removed in v0.18. You should instead pass "
                "the outputs of Renderer.render_cmpl() or Renderer.render_chat()."
            )

            request = prompt
            if request_id != request.request_id:
                logger.warning_once(
                    "AsyncLLM.add_request() was passed a request_id parameter that "
                    "does not match the EngineCoreRequest.request_id attribute. The "
                    "latter will be used, and the former will be ignored."
                )
        else:
            request = self.input_processor.process_inputs(
                request_id,
                prompt,
                params,
                supported_tasks=await self.get_supported_tasks(),
                arrival_time=arrival_time,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
            )
            prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)

        if reasoning_ended is not None:
            request.reasoning_ended = reasoning_ended

        self.input_processor.assign_request_id(request)

        # We start the output_handler on the first call to add_request() so
        # we can call __init__ before the event loop, which enables us
        # to handle startup failure gracefully in the OpenAI server.
        self._run_output_handler()

        # Create a new output collector for the request.
        queue = RequestOutputCollector(params.output_kind, request.request_id)

        # Use cloned params that may have been updated in process_inputs()
        params = request.params

        if is_pooling or params.n == 1:
            await self._add_request(request, prompt_text, None, 0, queue)
            return queue

        parent_params = params
        assert isinstance(parent_params, SamplingParams)

        # Fan out child requests (for n>1). Earlier children are shallow
        # copies taken before the original is mutated; the last child reuses
        # the original request object.
        parent_request = ParentRequest(request)
        for idx in range(parent_params.n):
            child_id, child_params = parent_request.get_child_info(idx)
            child_request = request if idx == parent_params.n - 1 else copy(request)
            child_request.request_id = child_id
            child_request.sampling_params = child_params
            await self._add_request(
                child_request, prompt_text, parent_request, idx, queue
            )
        return queue

    async def _add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        index: int,
        queue: RequestOutputCollector,
    ):
        """Register ``request`` locally, then submit it to the engine core.

        Args:
            request: The request to submit.
            prompt: Prompt text handed to the OutputProcessor (may be None).
            parent_req: Parent of a fanned-out ``n > 1`` request, else None.
            index: Child index within ``parent_req`` (0 when there is none).
            queue: Collector that receives this request's outputs.
        """
        # Add the request to OutputProcessor (this process).
        self.output_processor.add_request(request, prompt, parent_req, index, queue)

        # Add the EngineCoreRequest to EngineCore (separate process).
        await self.engine_core.add_request_async(request)

        if self.log_requests:
            logger.info("Added request %s.", request.request_id)

    async def _add_streaming_input_request(
        self,
        request_id: str,
        input_stream: AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> RequestOutputCollector:
        """Add a request whose prompt arrives incrementally from ``input_stream``.

        A placeholder request (prompt token ids ``[0]``) is processed up front
        to validate the parameters and obtain the internal request id. It is
        submitted last, after the stream is exhausted, as the finished signal.
        A background task (stored on the returned collector as
        ``_input_stream_task``) processes each stream chunk into a resumable
        request under that same internal id and submits it.

        Errors raised while consuming the stream are put on the collector
        wrapped in ``InputStreamError``; the final request is still sent. If
        the task is cancelled, the final request is not sent.

        Raises:
            ValueError: If ``sampling_params`` is not usable for streaming.
        """
        self._validate_streaming_input_sampling_params(sampling_params)

        # Keyword arguments shared by every process_inputs() call below.
        inputs = dict(
            supported_tasks=await self.get_supported_tasks(),
            arrival_time=arrival_time,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            trace_headers=trace_headers,
            priority=priority,
            data_parallel_rank=data_parallel_rank,
        )

        if not sampling_params.skip_clone:
            sampling_params = sampling_params.clone()
            sampling_params.skip_clone = True

        # Create request for validation, also used as the finished signal
        # once the input stream is closed.
        final_req = self.input_processor.process_inputs(
            request_id=request_id,
            prompt=TokensPrompt(prompt_token_ids=[0]),
            params=sampling_params,
            **inputs,  # type: ignore[arg-type]
        )
        self.input_processor.assign_request_id(final_req)
        internal_req_id = final_req.request_id

        queue = RequestOutputCollector(sampling_params.output_kind, internal_req_id)

        async def handle_inputs():
            cancelled = False
            try:
                async for input_chunk in input_stream:
                    sp = input_chunk.sampling_params
                    if sp:
                        self._validate_streaming_input_sampling_params(sp)
                    else:
                        sp = sampling_params
                    # TODO(nick): Avoid re-validating reused sampling parameters
                    req = self.input_processor.process_inputs(
                        request_id=internal_req_id,
                        prompt=input_chunk.prompt,
                        params=sp,
                        resumable=True,
                        **inputs,  # type: ignore[arg-type]
                    )
                    req.external_req_id = request_id
                    if req.prompt_embeds is not None:
                        raise ValueError(
                            "prompt_embeds not supported for streaming inputs"
                        )
                    prompt_text, _, _ = extract_prompt_components(
                        self.model_config, input_chunk.prompt
                    )
                    await self._add_request(req, prompt_text, None, 0, queue)
            except (asyncio.CancelledError, GeneratorExit):
                # The cancellation is deliberately not re-raised; it only
                # suppresses the final request below.
                cancelled = True
            except Exception as error:
                # Wrap in InputStreamError so generate() can propagate it
                # without wrapping in EngineGenerateError.
                queue.put(InputStreamError(error))
            finally:
                queue._input_stream_task = None
                if not cancelled:
                    # Send empty final request to indicate that inputs have
                    # finished. Don't send if cancelled (session was aborted).
                    await self._add_request(final_req, None, None, 0, queue)

        # Ensure output handler is running.
        self._run_output_handler()

        queue._input_stream_task = asyncio.create_task(handle_inputs())
        return queue

    @staticmethod
    def _validate_streaming_input_sampling_params(
        params: SamplingParams | PoolingParams,
    ):
        """Reject parameters that input streaming cannot handle.

        Raises:
            ValueError: If ``params`` is not ``SamplingParams``, asks for more
                than one sample, uses ``FINAL_ONLY`` output, or sets ``stop``.
        """
        streamable = (
            isinstance(params, SamplingParams)
            and params.n <= 1
            and params.output_kind != RequestOutputKind.FINAL_ONLY
            and not params.stop
        )
        if streamable:
            return
        raise ValueError(
            "Input streaming not currently supported "
            "for pooling models, n > 1, request_kind = FINAL_ONLY "
            "or with stop strings."
        )

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: EngineCoreRequest
        | PromptType
        | EngineInput
        | AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        reasoning_ended: bool | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Making an AsyncStream corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the Detokenizer.
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request AsyncStream.

        The caller of generate() iterates the returned AsyncGenerator,
        returning the RequestOutput back to the caller.

        Raises:
            EngineDeadError: Re-raised unchanged; the request is not aborted.
            ValueError: Re-raised unchanged (request validation errors).
            Exception: An error from a streaming-input generator is re-raised
                as the original exception.
            EngineGenerateError: Wraps any other exception; the request is
                aborted first.
        """
        q: RequestOutputCollector | None = None
        try:
            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                prompt_text=prompt_text,
                reasoning_ended=reasoning_ended,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                assert isinstance(out, RequestOutput)
                finished = out.finished
                # STREAM_FINISHED is never yielded to the caller.
                if out is not STREAM_FINISHED:
                    yield out

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError as e:
            if self.log_requests:
                logger.info("Request %s failed (bad request): %s.", request_id, e)
            raise

        # Error from input stream generator - propagate directly.
        except InputStreamError as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed (input error): %s.", request_id, e)
            raise e.cause from e

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                try:
                    s = f"{e.__class__.__name__}: {e}"
                except Exception as e2:
                    # str(e) itself raised; report that instead of failing.
                    s = (
                        f"{e.__class__.__name__}: "
                        f"<str() of exception raised {e2.__class__.__name__}>"
                    )
                logger.info("Request %s failed due to %s.", request_id, s)
            raise EngineGenerateError() from e

        finally:
            if q is not None:
                q.close()

    def _run_output_handler(self):
        """Background loop: pulls from EngineCore and pushes to AsyncStreams.

        Does nothing if the handler task already exists. Otherwise it creates
        it with ``asyncio.create_task``, which requires a running event loop.
        If the loop raises, the error is logged, handed to
        ``OutputProcessor.propagate_error`` and the task ends.
        """
        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        # We use a mutable list for logger_manager so that it can be updated
        # during elastic EP scaling (see scale_elastic_ep) without creating
        # a circular reference via self.
        self._logger_ref = [self.logger_manager]
        logger_ref = self._logger_ref
        renderer = self.renderer
        chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    num_outputs = len(outputs.outputs)

                    iteration_stats = (
                        IterationStats() if (log_stats and num_outputs) else None
                    )

                    # Split outputs into chunks of at most
                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                    # event loop for too long.
                    engine_core_outputs = outputs.outputs
                    for start in range(0, num_outputs, chunk_size):
                        end = start + chunk_size
                        outputs_slice = engine_core_outputs[start:end]
                        # 2) Process EngineCoreOutputs.
                        processed_outputs = output_processor.process_outputs(
                            outputs_slice, outputs.timestamp, iteration_stats
                        )
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed_outputs.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if end < num_outputs:
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        if processed_outputs.reqs_to_abort:
                            await engine_core.abort_requests_async(
                                processed_outputs.reqs_to_abort
                            )

                    output_processor.update_scheduler_stats(outputs.scheduler_stats)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if logger_ref[0]:
                        logger_ref[0].record(
                            engine_idx=outputs.engine_index,
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                            mm_cache_stats=renderer.stat_mm_cache(),
                        )
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())

    async def abort(
        self, request_id: str | Iterable[str], internal: bool = False
    ) -> None:
        """Abort RequestId in OutputProcessor and EngineCore.

        Args:
            request_id: A single request id or an iterable of ids.
            internal: Forwarded to ``OutputProcessor.abort_requests``.
                ``generate()`` passes ``True`` together with the collector's
                (internal) request id — presumably to mark the ids as
                internal rather than caller-facing; confirm against
                OutputProcessor.
        """
        request_ids = (
            (request_id,) if isinstance(request_id, str) else as_list(request_id)
        )
        # The OutputProcessor returns the full set of ids the engine core
        # must abort, which is what gets forwarded below.
        all_request_ids = self.output_processor.abort_requests(request_ids, internal)
        await self.engine_core.abort_requests_async(all_request_ids)

        if self.log_requests:
            logger.info("Aborted request(s) %s.", ",".join(request_ids))

    async def pause_generation(
        self,
        *,
        mode: PauseMode = "abort",
        wait_for_inflight_requests: bool | None = None,
        clear_cache: bool = True,
    ) -> None:
        """
        Pause generation to allow model weight updates.

        All mode handling (abort / wait / keep) and cache clearing is done
        in the engine. New generation/encoding requests will not be scheduled
        until resume is called.

        Args:
            mode: How to handle in-flight requests:
                - ``"abort"``: Abort all in-flight requests immediately
                  (default).
                - ``"wait"``: Wait for in-flight requests to complete.
                - ``"keep"``: Freeze requests in queue; they resume on
                  :meth:`resume_generation`.
            wait_for_inflight_requests: DEPRECATED: use mode argument.
                Passing any non-``None`` value emits a ``DeprecationWarning``;
                a truthy value overrides ``mode`` with ``"wait"``.
            clear_cache: Whether to clear KV cache and prefix cache after
                draining. Set to ``False`` to preserve cache for faster resume.
        """
        # Warn on any explicit use of the deprecated argument (including an
        # explicit False), but only a truthy value changes the mode.
        if wait_for_inflight_requests is not None:
            warnings.warn(
                "The `wait_for_inflight_requests` parameter in "
                "`AsyncLLM.pause_generation()` is deprecated. "
                "Please use `mode` argument instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            if wait_for_inflight_requests:
                mode = "wait"
        await self.engine_core.pause_scheduler_async(mode=mode, clear_cache=clear_cache)
        # Small sleep to help ensure that final outputs from any in-flight requests are
        # returned prior to this method returning. These outputs come out of the engine
        # prior to the wait-for-idle completion event, but involve additional async
        # tasks in output processing.
        # Note that this is not required for correctness, just more intuitive ordering
        # of events from caller's pov.
        await asyncio.sleep(0.02)

    async def resume_generation(self) -> None:
        """Resume scheduling on the engine core after :meth:`pause_generation`."""
        engine_core = self.engine_core
        await engine_core.resume_scheduler_async()
764
765
766

    async def is_paused(self) -> bool:
        """Report whether the engine core's scheduler is currently paused."""
        paused = await self.engine_core.is_scheduler_paused_async()
        return paused
768

769
    async def encode(
        self,
        prompt: PromptType | EngineInput,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
        reasoning_ended: bool | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Main function called by the API server to kick off a pooling request:
            * 1) Creating the per-request ``RequestOutputCollector`` queue.
            * 2) Processing the input.
            * 3) Adding the request to the EngineCore (separate process).

        Steps 1-3 are carried out by :meth:`add_request`. A separate
        output_handler loop runs in a background asyncio task, pulling outputs
        from EngineCore and pushing them into the per-request queue.

        The caller of encode() iterates the returned AsyncGenerator, receiving
        ``PoolingRequestOutput`` objects until one has ``finished`` set.

        Args:
            prompt: The prompt or pre-processed engine input.
            pooling_params: Pooling parameters for this request.
            request_id: Id of the request; also used in log messages.
            lora_request: Optional LoRA adapter, forwarded to add_request.
            trace_headers: Optional tracing headers, forwarded to add_request.
            priority: Scheduling priority, forwarded to add_request.
            tokenization_kwargs: Optional tokenization kwargs, forwarded to
                add_request.
            reasoning_ended: Forwarded to add_request.

        Raises:
            asyncio.CancelledError: Re-raised after the request is aborted.
            EngineDeadError: Re-raised without aborting (engine is shut down).
            ValueError: Re-raised for request validation errors.
            EngineGenerateError: Wraps any other exception, after aborting the
                request.
        """

        q: RequestOutputCollector | None = None
        try:
            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                reasoning_ended=reasoning_ended,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, encode()
        # is cancelled. So, we abort the request if we end up here.
        except asyncio.CancelledError:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e
        finally:
            if q is not None:
                q.close()
851

852
    @property
    def tokenizer(self) -> TokenizerLike | None:
        """The tokenizer exposed by the renderer (may be ``None``)."""
        renderer = self.renderer
        return renderer.tokenizer
855

856
    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer obtained from ``renderer.get_tokenizer()``."""
        renderer = self.renderer
        return renderer.get_tokenizer()
858
859

    async def is_tracing_enabled(self) -> bool:
        """Tracing counts as enabled whenever an OTLP traces endpoint is set."""
        endpoint = self.observability_config.otlp_traces_endpoint
        return endpoint is not None
861

862
    async def do_log_stats(self) -> None:
        """Emit stats through ``self.logger_manager``.

        Does nothing when ``self.logger_manager`` is ``None``.
        """
        # Explicit ``is not None`` (matching sleep()/wake_up()) so a manager
        # object that happens to be falsy is still asked to log.
        if self.logger_manager is not None:
            self.logger_manager.log()
865
866
867

    async def check_health(self) -> None:
        """Raise ``self.dead_error`` if the engine is in an errored state.

        Returns normally when ``self.errored`` is falsy.
        """
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error
870

871
872
    async def start_profile(self, profile_prefix: str | None = None) -> None:
        coros = [self.engine_core.profile_async(True, profile_prefix)]
873
874
875
        if self.profiler is not None:
            coros.append(asyncio.to_thread(self.profiler.start))
        await asyncio.gather(*coros)
876
877

    async def stop_profile(self) -> None:
        """Stop profiling on the engine core and, if configured, locally."""
        engine_stop = self.engine_core.profile_async(False)
        if self.profiler is None:
            await asyncio.gather(engine_stop)
            return
        # ``self.profiler.stop`` is run in a worker thread via to_thread,
        # concurrently with the engine-core request.
        await asyncio.gather(engine_stop, asyncio.to_thread(self.profiler.stop))
882

883
    async def reset_mm_cache(self) -> None:
        """Clear the multimodal cache: renderer side first, then engine core."""
        await self.renderer.clear_mm_cache_async()
        engine_core = self.engine_core
        await engine_core.reset_mm_cache_async()

887
888
889
890
891
892
    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the engine core's prefix cache.

        Args:
            reset_running_requests: Forwarded to the engine core.
            reset_connector: Forwarded to the engine core.

        Returns:
            The result reported by the engine core.
        """
        result = await self.engine_core.reset_prefix_cache_async(
            reset_running_requests, reset_connector
        )
        return result
893

894
895
896
    async def reset_encoder_cache(self) -> None:
        """Ask the engine core to reset its encoder cache."""
        engine_core = self.engine_core
        await engine_core.reset_encoder_cache_async()

897
898
    async def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
        """Put the engine core to sleep and record the sleep state.

        Args:
            level: Sleep level, forwarded to the engine core and recorded via
                ``logger_manager.record_sleep_state(1, level)``.
            mode: Pause mode forwarded to the engine core.
        """
        await self.engine_core.sleep_async(level, mode)
        manager = self.logger_manager
        if manager is None:
            return
        manager.record_sleep_state(1, level)

903
    async def wake_up(self, tags: list[str] | None = None) -> None:
904
        await self.engine_core.wake_up_async(tags)
905

906
907
908
        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

909
910
911
    async def is_sleeping(self) -> bool:
        """Report whether the engine core is currently asleep."""
        asleep = await self.engine_core.is_sleeping_async()
        return asleep

912
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for use by future requests.

        Returns the result reported by the engine core.
        """
        loaded = await self.engine_core.add_lora_async(lora_request)
        return loaded

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove the already-loaded LoRA adapter ``lora_id``.

        Returns the result reported by the engine core.
        """
        removed = await self.engine_core.remove_lora_async(lora_id)
        return removed

920
    async def list_loras(self) -> set[int]:
        """Return the ids of all LoRA adapters registered with the engine core."""
        registered = await self.engine_core.list_loras_async()
        return registered

    async def pin_lora(self, lora_id: int) -> bool:
        """Pin LoRA adapter ``lora_id`` so it is not evicted.

        Returns the result reported by the engine core.
        """
        pinned = await self.engine_core.pin_lora_async(lora_id)
        return pinned
927

928
929
930
    async def collective_rpc(
        self,
        method: str,
931
        timeout: float | None = None,
932
        args: tuple = (),
933
        kwargs: dict | None = None,
934
    ):
935
936
937
938
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
939
940
            method, timeout, args, kwargs
        )
941

942
943
944
945
946
947
948
949
    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """Wait until ``engine_core.dp_engines_running()`` reports idle engines.

        The engines are polled roughly once per second. The state is always
        checked at least once, and once more after the final (deadline-bounded)
        sleep, so a drain that completes just before the deadline is not
        reported as a timeout.

        Args:
            drain_timeout: Maximum time to wait, in seconds.

        Raises:
            TimeoutError: If the engines are still running once
                ``drain_timeout`` seconds have elapsed.
        """
        # Monotonic clock: wall-clock adjustments (e.g. NTP) must not shorten
        # or extend the wait.
        deadline = time.monotonic() + drain_timeout
        while True:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break

            logger.info("Engines are still running, waiting for requests to drain...")
            # Never sleep past the deadline.
            await asyncio.sleep(min(1.0, remaining))

        raise TimeoutError(
            f"Timeout reached after {drain_timeout} seconds "
            "waiting for requests to drain."
        )
957

958
959
960
    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ):
        """
        Change the data parallel size by adding or removing engine cores.

        Args:
            new_data_parallel_size: Target number of data parallel workers.
            drain_timeout: Maximum number of seconds to wait for requests to
                drain; only used when ``VLLM_ELASTIC_EP_DRAIN_REQUESTS`` is set.
        """
        current_size = self.vllm_config.parallel_config.data_parallel_size
        if current_size == new_data_parallel_size:
            logger.info(
                "Data parallel size is already %s, skipping scale",
                new_data_parallel_size,
            )
            return

        if envs.VLLM_ELASTIC_EP_DRAIN_REQUESTS:
            logger.info(
                "VLLM_ELASTIC_EP_DRAIN_REQUESTS is set, "
                "waiting for requests to drain before scaling"
            )
            await self.wait_for_requests_to_drain(drain_timeout)

        # Stat loggers are rebuilt (before scaling) only when growing and
        # stats logging is enabled.
        scaling_up = new_data_parallel_size > current_size
        if scaling_up and self.log_stats:
            # TODO(rob): fix this after talking with Ray team.
            # This resets all the prometheus metrics since we
            # unregister during initialization. Need to understand
            # the intended behavior here better.
            new_manager = StatLoggerManager(
                vllm_config=self.vllm_config,
                engine_idxs=list(range(new_data_parallel_size)),
                custom_stat_loggers=None,
            )
            self.logger_manager = new_manager
            # Update the mutable ref so output_handler picks up the
            # new logger without creating a circular reference via self.
            if hasattr(self, "_logger_ref"):
                self._logger_ref[0] = new_manager
            new_manager.log_engine_initialized()

        # The scaling flag is always cleared, even if scaling raises; the
        # config is only updated once the engine core scale call succeeds.
        set_scaling_elastic_ep(True)
        try:
            await self.engine_core.scale_elastic_ep(new_data_parallel_size)
            self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
        finally:
            set_scaling_elastic_ep(False)
1007

1008
1009
    @property
    def is_running(self) -> bool:
        """True until the output handler task has finished.

        Also True before the output handler loop has been started
        (``self.output_handler`` is ``None``).
        """
        handler = self.output_handler
        if handler is None:
            return True
        return not handler.done()
1012
1013
1014

    @property
    def is_stopped(self) -> bool:
        """Alias for :attr:`errored`."""
        stopped = self.errored
        return stopped
1016
1017
1018

    @property
    def errored(self) -> bool:
        """Truthy if the engine core reports itself dead or :attr:`is_running` is False."""
        engine_dead = self.engine_core.resources.engine_dead
        if engine_dead:
            return engine_dead
        return not self.is_running
1020
1021
1022

    @property
    def dead_error(self) -> BaseException:
        """A new ``EngineDeadError`` instance on every access."""
        error = EngineDeadError()
        return error
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064

    async def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest
    ) -> None:
        """
        Initialize weight transfer for RL training.

        Sends ``request.init_info`` through
        ``collective_rpc("init_weight_transfer_engine", ...)``.

        Args:
            request: Weight transfer initialization request with backend-specific info

        Raises:
            TypeError: If ``request`` is not a ``WeightTransferInitRequest``.
        """
        if not isinstance(request, WeightTransferInitRequest):
            raise TypeError(f"Expected WeightTransferInitRequest, got {type(request)}")

        await self.collective_rpc(
            "init_weight_transfer_engine", kwargs={"init_info": request.init_info}
        )

    async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
        """
        Apply a batched weight update for RL training.

        Args:
            request: Weight update request carrying backend-specific update info.

        Raises:
            TypeError: If ``request`` is not a ``WeightTransferUpdateRequest``.
        """
        if not isinstance(request, WeightTransferUpdateRequest):
            raise TypeError(
                f"Expected WeightTransferUpdateRequest, got {type(request)}"
            )

        rpc_kwargs = {"update_info": request.update_info}
        await self.collective_rpc("update_weights", kwargs=rpc_kwargs)