async_llm.py 41.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import os
import socket
6
import time
7
import warnings
8
from collections.abc import AsyncGenerator, Iterable, Mapping
9
from copy import copy
10
from typing import Any
11

12
import torch
13

14
import vllm.envs as envs
15
from vllm import TokensPrompt
16
from vllm.config import VllmConfig
17
18
19
20
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
21
from vllm.engine.arg_utils import AsyncEngineArgs
22
from vllm.engine.protocol import EngineClient, StreamingInput
23
from vllm.entrypoints.serve.elastic_ep.middleware import set_scaling_elastic_ep
24
from vllm.inputs import EngineInput, PromptType
25
26
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
27
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
28
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
29
from vllm.pooling_params import PoolingParams
30
from vllm.renderers import renderer_from_config
31
from vllm.renderers.inputs.preprocess import extract_prompt_components
32
from vllm.sampling_params import RequestOutputKind, SamplingParams
33
from vllm.tasks import SupportedTask
34
from vllm.tokenizers import TokenizerLike
35
from vllm.tracing import init_tracer
36
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
37
from vllm.usage.usage_lib import UsageContext
38
39
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
40
from vllm.v1.engine import EngineCoreRequest, PauseMode
41
from vllm.v1.engine.core_client import EngineCoreClient
42
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
43
from vllm.v1.engine.input_processor import InputProcessor
44
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
45
from vllm.v1.engine.parallel_sampling import ParentRequest
46
from vllm.v1.executor import Executor
47
48
49
50
51
from vllm.v1.metrics.loggers import (
    StatLoggerFactory,
    StatLoggerManager,
    load_stat_logger_plugin_factories,
)
52
from vllm.v1.metrics.prometheus import shutdown_prometheus
53
from vllm.v1.metrics.stats import IterationStats
54
55
56
57

logger = init_logger(__name__)


58
59
60
61
62
63
64
65
66
67
68
69
class InputStreamError(Exception):
    """Carrier for an exception raised by the user's input stream generator.

    The wrapped exception is kept on ``cause`` so that it can be re-raised
    as-is rather than being wrapped in EngineGenerateError.
    """

    def __init__(self, cause: Exception):
        # The message of this wrapper mirrors the message of the wrapped error.
        super().__init__(str(cause))
        self.cause = cause


70
class AsyncLLM(EngineClient):
71
72
    """An asynchronous wrapper for the vLLM engine."""

73
74
75
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: list[StatLoggerFactory] | None = None,
        aggregate_engine_logging: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Builds the renderer, the input/output processors and the engine core
        client, optionally sets up stat logging and a frontend torch profiler,
        and starts the output handler task immediately when called from
        inside a running asyncio event loop.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats. Stats are also enabled when
                custom stat loggers are passed in or loaded from plugins;
                in that case the default stat loggers stay disabled.
            usage_context: Usage context of the LLM (not referenced in this
                constructor body).
            mm_registry: Multi-modal registry (not referenced in this
                constructor body).
            use_cached_outputs: Whether to use cached outputs (not referenced
                in this constructor body).
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop (not
                referenced in this constructor body).
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            aggregate_engine_logging: Forwarded to StatLoggerManager.
            client_addresses: Forwarded to the engine core client factory.
            client_count: Forwarded to the engine core client factory and to
                StatLoggerManager.
            client_index: Forwarded to the engine core client factory.

        Returns:
            None
        """
        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.observability_config = vllm_config.observability_config

        # Tracing is switched on purely by the presence of an OTLP endpoint.
        otlp_endpoint = self.observability_config.otlp_traces_endpoint
        tracing_enabled = otlp_endpoint is not None
        if tracing_enabled:
            init_tracer("vllm.llm_engine", otlp_endpoint)

        self.log_requests = log_requests

        # Caller-supplied factories first, then any factories from plugins.
        custom_stat_loggers = [
            *(stat_loggers or ()),
            *load_stat_logger_plugin_factories(),
        ]
        has_custom_loggers = bool(custom_stat_loggers)
        self.log_stats = log_stats or has_custom_loggers
        if has_custom_loggers and not log_stats:
            logger.info(
                "AsyncLLM created with log_stats=False, "
                "but custom stat loggers were found; "
                "enabling logging without default stat loggers."
            )

        renderer = renderer_from_config(self.vllm_config)
        self.renderer = renderer

        # Convert EngineInput --> EngineCoreRequest.
        self.input_processor = InputProcessor(self.vllm_config, renderer)

        # Converts EngineCoreOutputs --> RequestOutput.
        scheduler_config = self.vllm_config.scheduler_config
        self.output_processor = OutputProcessor(
            renderer.tokenizer,
            log_stats=self.log_stats,
            stream_interval=scheduler_config.stream_interval,
            tracing_enabled=tracing_enabled,
        )

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Loggers.
        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=custom_stat_loggers,
                # Default loggers only when the caller asked for stats.
                enable_default_loggers=log_stats,
                client_count=client_count,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        self._client_count = client_count

        self.output_handler: asyncio.Task | None = None
        try:
            # Start output handler eagerly if we are in the asyncio eventloop.
            # get_running_loop() raises RuntimeError when there is none, in
            # which case the handler is started later.
            asyncio.get_running_loop()
            self._run_output_handler()
        except RuntimeError:
            pass

        profiler_config = vllm_config.profiler_config
        if profiler_config.profiler != "torch" or profiler_config.ignore_frontend:
            self.profiler = None
        else:
            profiler_dir = profiler_config.torch_profiler_dir
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                profiler_dir,
            )
            # Trace files are tagged with host name and pid of this process.
            worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
            trace_handler = torch.profiler.tensorboard_trace_handler(
                profiler_dir,
                worker_name=worker_name,
                use_gzip=profiler_config.torch_profiler_use_gzip,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=profiler_config.torch_profiler_with_stack,
                on_trace_ready=trace_handler,
            )

204
205
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_log_requests: bool = False,
        aggregate_engine_logging: bool = False,
        disable_log_stats: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from an existing VllmConfig.

        The executor class is resolved from the config via
        ``Executor.get_class``. ``enable_log_requests`` and
        ``disable_log_stats`` are translated to the constructor's
        ``log_requests`` and (negated) ``log_stats`` arguments; everything
        else is passed through unchanged.
        """
        executor_class = Executor.get_class(vllm_config)
        log_stats = not disable_log_stats

        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            log_requests=enable_log_requests,
            log_stats=log_stats,
            aggregate_engine_logging=aggregate_engine_logging,
            usage_context=usage_context,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

233
234
235
236
237
238
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs.

        The engine config is derived from ``engine_args`` for the given
        ``usage_context``; request and stats logging follow the
        ``enable_log_requests`` / ``disable_log_stats`` engine arguments.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)

        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

258
259
260
    def __del__(self):
        """Run shutdown() when the object is finalized."""
        self.shutdown()

261
    def shutdown(self, timeout: float | None = None) -> None:
        """Shutdown, cleaning up the background proc and IPC.

        Args:
            timeout: Passed through to the engine core client's shutdown().
        """
        shutdown_prometheus()

        # Each attribute is fetched with a default so that a missing one is
        # skipped instead of raising (presumably for partially constructed
        # instances, since __del__ calls this too).
        renderer = getattr(self, "renderer", None)
        if renderer:
            renderer.shutdown()

        engine_core = getattr(self, "engine_core", None)
        if engine_core:
            engine_core.shutdown(timeout=timeout)

        output_handler = getattr(self, "output_handler", None)
        if output_handler is not None:
            cancel_task_threadsafe(output_handler)
274

275
    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the supported tasks, querying the engine core only once.

        The first result is stored on ``self._supported_tasks`` and reused
        by later calls.
        """
        if hasattr(self, "_supported_tasks"):
            return self._supported_tasks

        # Cache the result
        self._supported_tasks = await self.engine_core.get_supported_tasks_async()
        return self._supported_tasks
281

282
283
284
    async def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest
        | PromptType
        | EngineInput
        | AsyncGenerator[StreamingInput, None],
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        prompt_text: str | None = None,
        reasoning_ended: bool | None = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        An async-generator ``prompt`` is routed to the streaming-input path.
        An ``EngineCoreRequest`` is used as-is (deprecated); any other prompt
        is converted through the input processor, and ``prompt_text`` is then
        taken from the prompt itself. For sampling requests with ``n > 1``
        one child request per sample is submitted, all sharing the returned
        collector.

        Returns:
            The RequestOutputCollector that receives this request's outputs.

        Raises:
            EngineDeadError: if ``self.errored`` is set.
            ValueError: if prompt logprobs are requested while
                ``kv_sharing_fast_prefill`` is enabled.
            NotImplementedError: if ``reasoning_ended`` is given together
                with a streaming input.
        """
        if self.errored:
            raise EngineDeadError()

        is_pooling = isinstance(params, PoolingParams)

        fast_prefill = self.vllm_config.cache_config.kv_sharing_fast_prefill
        if fast_prefill and not is_pooling and params.prompt_logprobs:
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs"
            )

        if isinstance(prompt, AsyncGenerator):
            if reasoning_ended is not None:
                raise NotImplementedError

            # Streaming input case.
            return await self._add_streaming_input_request(
                request_id=request_id,
                input_stream=prompt,
                sampling_params=params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
            )

        # Convert Input --> Request.
        if not isinstance(prompt, EngineCoreRequest):
            supported_tasks = await self.get_supported_tasks()
            request = self.input_processor.process_inputs(
                request_id,
                prompt,
                params,
                supported_tasks=supported_tasks,
                arrival_time=arrival_time,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
            )
            # The text derived from the prompt replaces the prompt_text argument.
            prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
        else:
            logger.warning_once(
                "Passing EngineCoreRequest to AsyncLLM.generate() and .add_requests() "
                "is deprecated and will be removed in v0.18. You should instead pass "
                "the outputs of Renderer.render_cmpl() or Renderer.render_chat()."
            )

            request = prompt
            if request_id != request.request_id:
                logger.warning_once(
                    "AsyncLLM.add_request() was passed a request_id parameter that "
                    "does not match the EngineCoreRequest.request_id attribute. The "
                    "latter will be used, and the former will be ignored."
                )

        if reasoning_ended is not None:
            request.reasoning_ended = reasoning_ended

        self.input_processor.assign_request_id(request)

        # We start the output_handler on the first call to add_request() so
        # we can call __init__ before the event loop, which enables us
        # to handle startup failure gracefully in the OpenAI server.
        self._run_output_handler()

        # Create a new output collector for the request. The output kind is
        # read from the caller's params, before they are swapped out below.
        queue = RequestOutputCollector(params.output_kind, request.request_id)

        # Use cloned params that may have been updated in process_inputs()
        params = request.params

        if is_pooling or params.n == 1:
            await self._add_request(request, prompt_text, None, 0, queue)
            return queue

        parent_params = params
        assert isinstance(parent_params, SamplingParams)

        # Fan out child requests (for n>1).
        parent_request = ParentRequest(request)
        for idx in range(parent_params.n):
            child_request_id, child_params = parent_request.get_child_info(idx)
            # The last child reuses the original request object; the others
            # work on (shallow) copies of it.
            if idx == parent_params.n - 1:
                child_request = request
            else:
                child_request = copy(request)
            child_request.request_id = child_request_id
            child_request.sampling_params = child_params
            await self._add_request(
                child_request, prompt_text, parent_request, idx, queue
            )
        return queue
398

399
400
401
    async def _add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        index: int,
        queue: RequestOutputCollector,
    ):
        """Register a request with the OutputProcessor, then send it to EngineCore.

        Args:
            request: The request to submit.
            prompt: Prompt text handed to the OutputProcessor, if any.
            parent_req: Parent request when this is a child of an n>1 request.
            index: Index handed to the OutputProcessor along with parent_req.
            queue: Collector that will receive the request's outputs.
        """
        # Add the request to OutputProcessor (this process).
        self.output_processor.add_request(request, prompt, parent_req, index, queue)

        # Add the EngineCoreRequest to EngineCore (separate process).
        await self.engine_core.add_request_async(request)

        if not self.log_requests:
            return
        logger.info("Added request %s.", request.request_id)
415

416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
    async def _add_streaming_input_request(
        self,
        request_id: str,
        input_stream: AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> RequestOutputCollector:
        """Add a request whose prompt arrives in chunks from an async generator.

        A placeholder request is processed up front for validation; it is
        submitted later as the "inputs finished" signal. A background task
        (stored on the collector as ``_input_stream_task``) consumes
        ``input_stream`` and submits every chunk as a resumable request under
        the same internal request id.

        Returns:
            The RequestOutputCollector for the streaming request.

        Raises:
            ValueError: if ``sampling_params`` is not supported for
                streaming input.
        """
        self._validate_streaming_input_sampling_params(sampling_params)

        # Keyword arguments shared by every process_inputs() call below.
        inputs = {
            "supported_tasks": await self.get_supported_tasks(),
            "arrival_time": arrival_time,
            "lora_request": lora_request,
            "tokenization_kwargs": tokenization_kwargs,
            "trace_headers": trace_headers,
            "priority": priority,
            "data_parallel_rank": data_parallel_rank,
        }

        # Clone once here and flag the clone with skip_clone.
        if not sampling_params.skip_clone:
            sampling_params = sampling_params.clone()
            sampling_params.skip_clone = True

        # Create request for validation, also used as the finished signal
        # once the input stream is closed.
        final_req = self.input_processor.process_inputs(
            request_id=request_id,
            prompt=TokensPrompt(prompt_token_ids=[0]),
            params=sampling_params,
            **inputs,  # type: ignore[arg-type]
        )
        self.input_processor.assign_request_id(final_req)
        internal_req_id = final_req.request_id

        queue = RequestOutputCollector(sampling_params.output_kind, internal_req_id)

        async def handle_inputs():
            was_cancelled = False
            try:
                async for input_chunk in input_stream:
                    # Per-chunk params, when given, are validated; otherwise
                    # the request-level params are reused.
                    chunk_params = input_chunk.sampling_params
                    if not chunk_params:
                        chunk_params = sampling_params
                    else:
                        self._validate_streaming_input_sampling_params(chunk_params)

                    # TODO(nick): Avoid re-validating reused sampling parameters
                    chunk_req = self.input_processor.process_inputs(
                        request_id=internal_req_id,
                        prompt=input_chunk.prompt,
                        params=chunk_params,
                        resumable=True,
                        **inputs,  # type: ignore[arg-type]
                    )
                    chunk_req.external_req_id = request_id
                    if chunk_req.prompt_embeds is not None:
                        raise ValueError(
                            "prompt_embeds not supported for streaming inputs"
                        )
                    chunk_text, _, _ = extract_prompt_components(
                        self.model_config, input_chunk.prompt
                    )
                    await self._add_request(chunk_req, chunk_text, None, 0, queue)
            except (asyncio.CancelledError, GeneratorExit):
                was_cancelled = True
            except Exception as error:
                # Wrap in InputStreamError so generate() can propagate it
                # without wrapping in EngineGenerateError.
                queue.put(InputStreamError(error))
            finally:
                queue._input_stream_task = None
                if not was_cancelled:
                    # Send empty final request to indicate that inputs have
                    # finished. Don't send if cancelled (session was aborted).
                    await self._add_request(final_req, None, None, 0, queue)

        # Ensure output handler is running.
        self._run_output_handler()

        queue._input_stream_task = asyncio.create_task(handle_inputs())
        return queue

    @staticmethod
    def _validate_streaming_input_sampling_params(
        params: SamplingParams | PoolingParams,
    ) -> None:
        """Reject parameter sets that streaming input does not support.

        Args:
            params: The parameters to check.

        Raises:
            ValueError: if ``params`` is not a SamplingParams, requests
                ``n > 1``, uses ``output_kind == FINAL_ONLY``, or sets stop
                strings.
        """
        unsupported = (
            not isinstance(params, SamplingParams)
            or params.n > 1
            or params.output_kind == RequestOutputKind.FINAL_ONLY
            or params.stop
        )
        if unsupported:
            # The message names the attribute that is actually checked
            # (output_kind) so users know which setting to change.
            raise ValueError(
                "Input streaming not currently supported "
                "for pooling models, n > 1, output_kind = FINAL_ONLY "
                "or with stop strings."
            )

518
519
520
521
522
    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
523
    async def generate(
        self,
        prompt: EngineCoreRequest
        | PromptType
        | EngineInput
        | AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        reasoning_ended: bool | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request.

        The request is registered through add_request(), which returns a
        RequestOutputCollector. A separate output_handler loop runs in a
        background asyncio task and pushes outputs into that collector;
        this generator pulls them out and yields each RequestOutput to the
        caller until one is marked finished.

        Raises:
            asyncio.CancelledError, GeneratorExit: re-raised after the
                request has been aborted.
            EngineDeadError: re-raised as-is; the request is not aborted.
            ValueError: re-raised as-is (request validation error).
            EngineGenerateError: for any other unexpected error, after the
                request has been aborted.

        An error raised by a streaming input generator is re-raised as the
        original exception after the request has been aborted.
        """

        collector: RequestOutputCollector | None = None
        try:
            collector = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                prompt_text=prompt_text,
                reasoning_ended=reasoning_ended,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            while True:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                output = collector.get_nowait() or await collector.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                assert isinstance(output, RequestOutput)
                is_last = output.finished
                # The STREAM_FINISHED marker itself is never handed out.
                if output is not STREAM_FINISHED:
                    yield output
                if is_last:
                    break

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            if collector is not None:
                await self.abort(collector.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError as e:
            if self.log_requests:
                logger.info("Request %s failed (bad request): %s.", request_id, e)
            raise

        # Error from input stream generator - propagate directly.
        except InputStreamError as e:
            if collector is not None:
                await self.abort(collector.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed (input error): %s.", request_id, e)
            raise e.cause from e

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            if collector is not None:
                await self.abort(collector.request_id, internal=True)
            if self.log_requests:
                # Formatting the exception can itself fail; fall back to a
                # description built from class names only.
                try:
                    description = f"{e.__class__.__name__}: {e}"
                except Exception as e2:
                    description = (
                        f"{e.__class__.__name__}: "
                        "error during printing an exception of class"
                        f"{e2.__class__.__name__}"
                    )
                logger.info("Request %s failed due to %s.", request_id, description)
            raise EngineGenerateError() from e
        finally:
            if collector is not None:
                collector.close()
633
634
635
636
637
638
639
640
641
642
643
644

    def _run_output_handler(self):
        """Start the background task that drains EngineCore outputs.

        The task pulls EngineCoreOutputs from the engine core, runs them
        through the OutputProcessor in chunks, aborts requests the processor
        flags, and records stats. Calling this again once a handler task
        exists is a no-op.
        """

        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        # We use a mutable list for logger_manager so that it can be updated
        # during elastic EP scaling (see scale_elastic_ep) without creating
        # a circular reference via self.
        self._logger_ref = [self.logger_manager]
        logger_ref = self._logger_ref
        renderer = self.renderer
        chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    engine_core_outputs = outputs.outputs
                    num_outputs = len(engine_core_outputs)

                    # Stats are only collected when logging is on and there
                    # is at least one output in this batch.
                    if log_stats and num_outputs:
                        iteration_stats = IterationStats()
                    else:
                        iteration_stats = None

                    # Split outputs into chunks of at most
                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                    # event loop for too long.
                    for chunk_start in range(0, num_outputs, chunk_size):
                        chunk_end = chunk_start + chunk_size

                        # 2) Process EngineCoreOutputs.
                        processed = output_processor.process_outputs(
                            engine_core_outputs[chunk_start:chunk_end],
                            outputs.timestamp,
                            iteration_stats,
                        )
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if chunk_end < num_outputs:
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        if processed.reqs_to_abort:
                            await engine_core.abort_requests_async(
                                processed.reqs_to_abort
                            )

                    output_processor.update_scheduler_stats(outputs.scheduler_stats)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    stat_logger = logger_ref[0]
                    if stat_logger:
                        stat_logger.record(
                            engine_idx=outputs.engine_index,
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                            mm_cache_stats=renderer.stat_mm_cache(),
                        )
            except Exception as e:
                # Any failure ends the loop; the error is handed to the
                # OutputProcessor for propagation.
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())
705

706
707
708
    async def abort(
        self, request_id: str | Iterable[str], internal: bool = False
    ) -> None:
        """Abort RequestId in OutputProcessor and EngineCore.

        Args:
            request_id: A single request id or an iterable of request ids.
            internal: Passed through to the OutputProcessor's
                abort_requests().
        """
        # Normalize to a sequence; a bare string counts as one id.
        if isinstance(request_id, str):
            request_ids = (request_id,)
        else:
            request_ids = as_list(request_id)

        # The OutputProcessor returns the ids to abort in the EngineCore.
        all_request_ids = self.output_processor.abort_requests(request_ids, internal)
        await self.engine_core.abort_requests_async(all_request_ids)

        if not self.log_requests:
            return
        logger.info("Aborted request(s) %s.", ",".join(request_ids))
719

720
721
722
    async def pause_generation(
        self,
        *,
        mode: PauseMode = "abort",
        wait_for_inflight_requests: bool | None = None,
        clear_cache: bool = True,
    ) -> None:
        """
        Pause generation to allow model weight updates.

        All mode handling (abort / wait / keep) and cache clearing is done
        in the engine. New generation/encoding requests will not be scheduled
        until resume is called.

        Args:
            mode: How to handle in-flight requests:
                - ``"abort"``: Abort all in-flight requests immediately
                  (default).
                - ``"wait"``: Wait for in-flight requests to complete.
                - ``"keep"``: Freeze requests in queue; they resume on
                  :meth:`resume_generation`.
            wait_for_inflight_requests: DEPRECATED: use mode argument.
                Passing any value (``True`` or ``False``) emits a
                ``DeprecationWarning``. ``True`` overrides ``mode`` with
                ``"wait"``; ``False`` leaves ``mode`` untouched.
            clear_cache: Whether to clear KV cache and prefix cache after
                draining. Set to ``False`` to preserve cache for faster resume.
        """
        # ``None`` is the "not supplied" sentinel: warn whenever the caller
        # used the deprecated argument, not only when it was truthy.
        if wait_for_inflight_requests is not None:
            warnings.warn(
                "The `wait_for_inflight_requests` parameter in "
                "`AsyncLLM.pause_generation()` is deprecated. "
                "Please use `mode` argument instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            if wait_for_inflight_requests:
                mode = "wait"
        await self.engine_core.pause_scheduler_async(mode=mode, clear_cache=clear_cache)
        # Small sleep to help ensure that final outputs from any in-flight requests are
        # returned prior to this method returning. These outputs come out of the engine
        # prior to the wait-for-idle completion event, but involve additional async
        # tasks in output processing.
        # Note that this is not required for correctness, just more intuitive ordering
        # of events from caller's pov.
        await asyncio.sleep(0.02)
762
763
764

    async def resume_generation(self) -> None:
        """Resume scheduling after an earlier :meth:`pause_generation` call."""
        engine = self.engine_core
        await engine.resume_scheduler_async()
766
767
768

    async def is_paused(self) -> bool:
        """Report whether the engine core's scheduler is currently paused."""
        paused = await self.engine_core.is_scheduler_paused_async()
        return paused
770

771
    async def encode(
        self,
        prompt: PromptType | EngineInput,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
        reasoning_ended: bool | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Submit a pooling request and stream its outputs back to the caller.

        The request is registered via :meth:`add_request`, which hands back a
        per-request ``RequestOutputCollector``. A separate output_handler
        loop running in a background asyncio task pushes outputs into that
        collector; this generator pulls them out and yields each
        ``PoolingRequestOutput`` until one is marked ``finished``.

        Args:
            prompt: The prompt (or already-processed engine input) to encode.
            pooling_params: Pooling parameters for the request.
            request_id: Identifier used when adding the request and in logs.
            lora_request: Forwarded to :meth:`add_request`.
            trace_headers: Forwarded to :meth:`add_request`.
            priority: Forwarded to :meth:`add_request`.
            tokenization_kwargs: Forwarded to :meth:`add_request`.
            reasoning_ended: Forwarded to :meth:`add_request`.

        Yields:
            Each ``PoolingRequestOutput`` produced for the request.

        Raises:
            asyncio.CancelledError: Re-raised after aborting the request.
            EngineDeadError: Re-raised as-is (no abort is attempted).
            ValueError: Re-raised as-is (request validation failure).
            EngineGenerateError: Wraps any other exception, after aborting
                the request.
        """

        collector: RequestOutputCollector | None = None
        try:
            collector = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                reasoning_ended=reasoning_ended,
            )

            # The output_handler task fills the collector; this loop drains
            # it and hands each item to the caller.
            while True:
                # Take an already-available item without awaiting when
                # possible (avoids task switching under load, which helps
                # performance).
                out = collector.get_nowait() or await collector.get()
                assert isinstance(out, PoolingRequestOutput)
                # OutputProcessor and EngineCore each perform their own
                # request cleanup once an output is finished. Read the flag
                # before yielding control to the caller.
                is_final = out.finished
                yield out
                if is_final:
                    break

        # A client disconnect cancels the task running this generator, so
        # the request is aborted here.
        except asyncio.CancelledError:
            if collector is not None:
                await self.abort(collector.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. No abort is sent because we are shutting down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error while serving this request (possibly recoverable).
        except Exception as e:
            if collector is not None:
                await self.abort(collector.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

        finally:
            # Always release the collector, whichever path was taken.
            if collector is not None:
                collector.close()
853

854
    @property
    def tokenizer(self) -> TokenizerLike | None:
        """Tokenizer exposed by the renderer (``self.renderer.tokenizer``)."""
        renderer = self.renderer
        return renderer.tokenizer
857

858
    def get_tokenizer(self) -> TokenizerLike:
        """Return the result of the renderer's ``get_tokenizer()``."""
        renderer = self.renderer
        return renderer.get_tokenizer()
860
861

    async def is_tracing_enabled(self) -> bool:
        """Return True when an OTLP traces endpoint is configured."""
        endpoint = self.observability_config.otlp_traces_endpoint
        return endpoint is not None
863

864
    async def do_log_stats(self) -> None:
        """Invoke ``log()`` on the stat logger manager, if one is set."""
        manager = self.logger_manager
        if not manager:
            return
        manager.log()
867
868
869

    async def check_health(self) -> None:
        """Raise ``self.dead_error`` if the client is in an errored state.

        Returns normally when ``self.errored`` is false.
        """
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error
872

873
874
    async def start_profile(self, profile_prefix: str | None = None) -> None:
        coros = [self.engine_core.profile_async(True, profile_prefix)]
875
876
877
        if self.profiler is not None:
            coros.append(asyncio.to_thread(self.profiler.start))
        await asyncio.gather(*coros)
878
879

    async def stop_profile(self) -> None:
        """Stop profiling in the engine core and, if set, the local profiler."""
        engine_stop = self.engine_core.profile_async(False)
        if self.profiler is None:
            await asyncio.gather(engine_stop)
            return
        # The local profiler's stop() is a blocking call, so it runs in a
        # worker thread concurrently with the engine-core request.
        await asyncio.gather(engine_stop, asyncio.to_thread(self.profiler.stop))
884

885
    async def reset_mm_cache(self) -> None:
        """Clear the renderer's multimodal cache, then the engine core's."""
        renderer = self.renderer
        await renderer.clear_mm_cache_async()

        engine = self.engine_core
        await engine.reset_mm_cache_async()

889
890
891
892
893
894
    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Ask the engine core to reset its prefix cache.

        Both flags are forwarded positionally to
        ``engine_core.reset_prefix_cache_async``; its result is returned.
        """
        outcome = await self.engine_core.reset_prefix_cache_async(
            reset_running_requests, reset_connector
        )
        return outcome
895

896
897
898
    async def reset_encoder_cache(self) -> None:
        """Ask the engine core to reset its encoder cache."""
        engine = self.engine_core
        await engine.reset_encoder_cache_async()

899
900
    async def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
        """Put the engine core to sleep and record the new sleep state.

        Args:
            level: Sleep level, forwarded to ``engine_core.sleep_async``.
            mode: Pause mode, forwarded to ``engine_core.sleep_async``.
        """
        await self.engine_core.sleep_async(level, mode)

        manager = self.logger_manager
        if manager is None:
            return
        manager.record_sleep_state(1, level)

905
    async def wake_up(self, tags: list[str] | None = None) -> None:
906
        await self.engine_core.wake_up_async(tags)
907

908
909
910
        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

911
912
913
    async def is_sleeping(self) -> bool:
        """Return the engine core's answer to ``is_sleeping_async()``."""
        sleeping = await self.engine_core.is_sleeping_async()
        return sleeping

914
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests.

        Returns the result reported by ``engine_core.add_lora_async``.
        """
        added = await self.engine_core.add_lora_async(lora_request)
        return added

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter, identified by ``lora_id``.

        Returns the result reported by ``engine_core.remove_lora_async``.
        """
        removed = await self.engine_core.remove_lora_async(lora_id)
        return removed

922
    async def list_loras(self) -> set[int]:
        """List all registered adapters, as reported by the engine core."""
        registered = await self.engine_core.list_loras_async()
        return registered

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent the adapter identified by ``lora_id`` from being evicted.

        Returns the result reported by ``engine_core.pin_lora_async``.
        """
        pinned = await self.engine_core.pin_lora_async(lora_id)
        return pinned
929

930
931
932
    async def collective_rpc(
        self,
        method: str,
933
        timeout: float | None = None,
934
        args: tuple = (),
935
        kwargs: dict | None = None,
936
    ):
937
938
939
940
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
941
942
            method, timeout, args, kwargs
        )
943

944
945
946
947
948
949
950
951
    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """Wait until the engines report idle, or raise on timeout.

        Polls ``engine_core.dp_engines_running()`` and sleeps one second
        between checks.

        Args:
            drain_timeout: Maximum time to wait, in seconds.

        Raises:
            TimeoutError: If the engines are still running once
                ``drain_timeout`` seconds have elapsed.
        """
        # Measure elapsed time with the monotonic clock: the wall clock
        # (time.time()) can jump forwards or backwards when the system time
        # is adjusted, which would shorten or stretch the timeout.
        started_at = time.monotonic()
        while time.monotonic() - started_at < drain_timeout:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

            logger.info("Engines are still running, waiting for requests to drain...")
            await asyncio.sleep(1)  # Wait 1 second before checking again

        raise TimeoutError(
            f"Timeout reached after {drain_timeout} seconds "
            "waiting for requests to drain."
        )
959

960
961
962
    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ):
        """
        Resize the data parallel deployment by adding or removing engine
        cores.

        Does nothing (besides logging) when the requested size equals the
        current one.

        Args:
            new_data_parallel_size: The new number of data parallel workers
            drain_timeout:
                Maximum time to wait for requests to drain (seconds). Only
                used when ``envs.VLLM_ELASTIC_EP_DRAIN_REQUESTS`` is set.
        """
        current_dp_size = self.vllm_config.parallel_config.data_parallel_size
        if current_dp_size == new_data_parallel_size:
            logger.info(
                "Data parallel size is already %s, skipping scale",
                new_data_parallel_size,
            )
            return

        if envs.VLLM_ELASTIC_EP_DRAIN_REQUESTS:
            logger.info(
                "VLLM_ELASTIC_EP_DRAIN_REQUESTS is set, "
                "waiting for requests to drain before scaling"
            )
            await self.wait_for_requests_to_drain(drain_timeout)

        # Recreate the stat loggers, but only when growing and only if stats
        # logging is enabled.
        scaling_up = new_data_parallel_size > current_dp_size
        if scaling_up and self.log_stats:
            # TODO(rob): fix this after talking with Ray team.
            # This resets all the prometheus metrics since we
            # unregister during initialization. Need to understand
            # the intended behavior here better.
            engine_idxs = list(range(new_data_parallel_size))
            self.logger_manager = StatLoggerManager(
                vllm_config=self.vllm_config,
                engine_idxs=engine_idxs,
                custom_stat_loggers=None,
            )
            # Refresh the mutable ref so output_handler sees the new logger
            # without creating a circular reference via self.
            if hasattr(self, "_logger_ref"):
                self._logger_ref[0] = self.logger_manager
            self.logger_manager.log_engine_initialized()

        # Flag that scaling is in progress for the duration of the engine
        # call; the flag is cleared even if scaling fails.
        set_scaling_elastic_ep(True)
        try:
            await self.engine_core.scale_elastic_ep(new_data_parallel_size)
            # Only record the new size once the engine core has scaled.
            self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
        finally:
            set_scaling_elastic_ep(False)
1009

1010
1011
    @property
    def is_running(self) -> bool:
        """True unless the output handler task exists and has completed."""
        handler = self.output_handler
        # The handler is None before the loop is started; that still counts
        # as running.
        if handler is None:
            return True
        return not handler.done()
1014
1015
1016

    @property
    def is_stopped(self) -> bool:
        """Alias for :attr:`errored`."""
        stopped = self.errored
        return stopped
1018
1019
1020

    @property
    def errored(self) -> bool:
        """True if the engine is flagged dead or the client is not running."""
        engine_dead = self.engine_core.resources.engine_dead
        # is_running is consulted only when the engine is not flagged dead.
        return engine_dead or not self.is_running
1022
1023
1024

    @property
    def dead_error(self) -> BaseException:
        """A newly constructed ``EngineDeadError`` instance."""
        error = EngineDeadError()
        return error
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066

    async def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest
    ) -> None:
        """
        Initialize weight transfer for RL training.

        Sends the request's ``init_info`` to the workers through
        :meth:`collective_rpc` under the ``init_info`` keyword.

        Args:
            request: Weight transfer initialization request with backend-specific info

        Raises:
            TypeError: If ``request`` is not a ``WeightTransferInitRequest``.
        """
        # Reject anything that is not the expected request type up front.
        if not isinstance(request, WeightTransferInitRequest):
            raise TypeError(f"Expected WeightTransferInitRequest, got {type(request)}")

        await self.collective_rpc(
            "init_weight_transfer_engine", kwargs={"init_info": request.init_info}
        )

    async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
        """
        Batched weight update for RL training.

        Sends the request's ``update_info`` to the workers through
        :meth:`collective_rpc` under the ``update_info`` keyword.

        Args:
            request: Weight update request with backend-specific update info

        Raises:
            TypeError: If ``request`` is not a ``WeightTransferUpdateRequest``.
        """
        # Reject anything that is not the expected request type up front.
        if not isinstance(request, WeightTransferUpdateRequest):
            raise TypeError(
                f"Expected WeightTransferUpdateRequest, got {type(request)}"
            )

        update_info = request.update_info
        await self.collective_rpc(
            "update_weights", kwargs={"update_info": update_info}
        )