async_llm.py 40.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import os
import socket
6
import time
7
import warnings
8
from collections.abc import AsyncGenerator, Iterable, Mapping
9
from copy import copy
10
from typing import Any
11

12
import torch
13

14
import vllm.envs as envs
15
from vllm import TokensPrompt
16
from vllm.config import VllmConfig
17
18
19
20
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
21
from vllm.engine.arg_utils import AsyncEngineArgs
22
from vllm.engine.protocol import EngineClient
23
from vllm.inputs import PromptType, StreamingInput
24
25
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
26
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
27
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
28
from vllm.plugins.io_processors import get_io_processor
29
from vllm.pooling_params import PoolingParams
30
from vllm.renderers import BaseRenderer, merge_kwargs
31
from vllm.sampling_params import RequestOutputKind, SamplingParams
32
from vllm.tasks import SupportedTask
33
from vllm.tokenizers import TokenizerLike
34
from vllm.tracing import init_tracer
35
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
36
from vllm.usage.usage_lib import UsageContext
37
38
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
39
from vllm.v1.engine import EngineCoreRequest
40
from vllm.v1.engine.core_client import EngineCoreClient
41
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
42
from vllm.v1.engine.input_processor import InputProcessor
43
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
44
from vllm.v1.engine.parallel_sampling import ParentRequest
45
from vllm.v1.engine.utils import get_prompt_text
46
from vllm.v1.executor import Executor
47
48
49
50
51
from vllm.v1.metrics.loggers import (
    StatLoggerFactory,
    StatLoggerManager,
    load_stat_logger_plugin_factories,
)
52
from vllm.v1.metrics.prometheus import shutdown_prometheus
53
from vllm.v1.metrics.stats import IterationStats
54
55
56
57

logger = init_logger(__name__)


58
59
60
61
62
63
64
65
66
67
68
69
class InputStreamError(Exception):
    """Carrier for an exception raised while consuming a user input stream.

    The original exception is kept on ``cause`` so that ``generate()`` can
    re-raise it unchanged instead of wrapping it in EngineGenerateError.
    The message of this wrapper is ``str(cause)``.
    """

    def __init__(self, cause: Exception):
        super().__init__(str(cause))
        self.cause = cause


70
71
72
73
class AsyncLLM(EngineClient):
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: list[StatLoggerFactory] | None = None,
        aggregate_engine_logging: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats. If custom (or plugin-provided)
                stat loggers exist, stats are collected even when this is
                False, but the default stat loggers stay disabled.
            usage_context: Usage context of the LLM. Not referenced by this
                constructor.
            mm_registry: Multi-modal registry. Not referenced by this
                constructor.
            use_cached_outputs: Not referenced by this constructor.
            log_requests: Whether to log requests.
            start_engine_loop: Not referenced by this constructor. The output
                handler is started eagerly when an event loop is running,
                otherwise on the first add_request() call.
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            aggregate_engine_logging: Forwarded to StatLoggerManager.
            client_addresses: Forwarded to the engine-core client.
            client_count: Number of frontend clients; forwarded to the
                engine-core client and to StatLoggerManager.
            client_index: Index of this client; forwarded to the engine-core
                client.

        Returns:
            None
        """
        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config

        tracing_endpoint = self.observability_config.otlp_traces_endpoint
        if tracing_endpoint is not None:
            init_tracer("vllm.llm_engine", tracing_endpoint)

        self.log_requests = log_requests

        # User-supplied loggers plus any registered through the plugin hook.
        custom_stat_loggers = list(stat_loggers or [])
        custom_stat_loggers.extend(load_stat_logger_plugin_factories())

        has_custom_loggers = bool(custom_stat_loggers)
        # Custom loggers force stats collection even if log_stats=False.
        self.log_stats = log_stats or has_custom_loggers
        if not log_stats and has_custom_loggers:
            logger.info(
                "AsyncLLM created with log_stats=False, "
                "but custom stat loggers were found; "
                "enabling logging without default stat loggers."
            )

        self.input_processor = InputProcessor(self.vllm_config)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(
            self.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
        )
        if tracing_endpoint is not None:
            self.output_processor.tracing_enabled = True

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Loggers.
        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=custom_stat_loggers,
                # Default loggers only when the caller asked for stats.
                enable_default_loggers=log_stats,
                client_count=client_count,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        # Pause / resume state for async RL workflows.
        self._pause_cond = asyncio.Condition()
        self._paused = False

        self.output_handler: asyncio.Task | None = None
        # Start output handler eagerly if we are in the asyncio eventloop.
        # Only the loop probe belongs in the try: a RuntimeError raised while
        # starting the handler itself must not be silently swallowed.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # No running loop; add_request() will start the handler.
        else:
            self._run_output_handler()

        if (
            vllm_config.profiler_config.profiler == "torch"
            and not vllm_config.profiler_config.ignore_frontend
        ):
            profiler_dir = vllm_config.profiler_config.torch_profiler_dir
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                profiler_dir,
            )
            # Host + PID makes trace file names unique per frontend process.
            worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=vllm_config.profiler_config.torch_profiler_with_stack,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    profiler_dir,
                    worker_name=worker_name,
                    use_gzip=vllm_config.profiler_config.torch_profiler_use_gzip,
                ),
            )
        else:
            self.profiler = None

205
206
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_log_requests: bool = False,
        aggregate_engine_logging: bool = False,
        disable_log_stats: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> "AsyncLLM":
        """Build an AsyncLLM from an already-constructed VllmConfig.

        The executor class is resolved from ``vllm_config`` with
        ``Executor.get_class``. ``enable_log_requests`` maps to the
        constructor's ``log_requests`` argument, and ``disable_log_stats`` is
        inverted into ``log_stats``. All other arguments are forwarded as-is.
        """
        resolved_executor = Executor.get_class(vllm_config)
        collect_stats = not disable_log_stats
        return cls(
            vllm_config=vllm_config,
            executor_class=resolved_executor,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            log_requests=enable_log_requests,
            log_stats=collect_stats,
            aggregate_engine_logging=aggregate_engine_logging,
            usage_context=usage_context,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

234
235
236
237
238
239
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs.

        The engine config is derived from ``engine_args`` for the given usage
        context, and the executor class is chosen from that config. Request
        logging and stats collection follow ``engine_args.enable_log_requests``
        and ``engine_args.disable_log_stats`` respectively.
        """
        engine_config = engine_args.create_engine_config(usage_context)
        return cls(
            vllm_config=engine_config,
            executor_class=Executor.get_class(engine_config),
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

259
260
261
    def __del__(self):
        """Release engine resources when the object is garbage collected.

        Delegates to shutdown(), which looks up each component with
        ``getattr(..., None)``. Components that a partially-run __init__
        never created are therefore skipped.
        """
        self.shutdown()

262
263
264
    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC.

        Stops the Prometheus integration first, then each component that
        exists: the engine-core client, the input processor, and the
        output-handler task. Components are fetched with
        ``getattr(..., None)``, so this also works on an instance whose
        __init__ did not finish.
        """
        shutdown_prometheus()

        core_client = getattr(self, "engine_core", None)
        if core_client:
            core_client.shutdown()

        processor = getattr(self, "input_processor", None)
        if processor:
            processor.close()

        output_task = getattr(self, "output_handler", None)
        if output_task is not None:
            cancel_task_threadsafe(output_task)
276

277
    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the engine.

        The engine core is queried only on the first call. The result is
        cached on the instance as ``_supported_tasks`` and reused afterwards.
        """
        try:
            return self._supported_tasks
        except AttributeError:
            pass  # Not cached yet; fall through and fetch it.

        self._supported_tasks = await self.engine_core.get_supported_tasks_async()
        return self._supported_tasks
283

284
285
286
    async def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        prompt_text: str | None = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        An async-generator prompt is handed to the streaming-input path. Any
        other prompt is converted to an EngineCoreRequest unless it already
        is one. The call then waits while generation is paused and submits
        the request. A sampling request with ``n > 1`` is fanned out into
        ``n`` child requests that share one collector.

        Returns:
            The RequestOutputCollector that receives this request's outputs.

        Raises:
            EngineDeadError: If the engine has errored.
            ValueError: If prompt logprobs are requested while
                ``kv_sharing_fast_prefill`` is enabled, or if ``prompt_text``
                is given with a prompt that is not an EngineCoreRequest.
        """
        if self.errored:
            raise EngineDeadError()

        pooling = isinstance(params, PoolingParams)

        fast_prefill = self.vllm_config.cache_config.kv_sharing_fast_prefill
        if fast_prefill and not pooling and params.prompt_logprobs:
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs"
            )

        legacy_truncation = params.truncate_prompt_tokens
        if legacy_truncation is not None:
            # Deprecated location for truncation; fold it into the kwargs.
            warnings.warn(
                f"The `truncate_prompt_tokens` parameter in `{type(params).__name__}` "
                "is deprecated and will be removed in v0.16. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                {"truncate_prompt_tokens": legacy_truncation},
            )

        if isinstance(prompt, AsyncGenerator):
            # Streaming input case.
            return await self._add_streaming_input_request(
                request_id,
                input_stream=prompt,
                sampling_params=params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
            )

        # Convert Input --> Request.
        if not isinstance(prompt, EngineCoreRequest):
            if prompt_text is not None:
                raise ValueError(
                    "should only provide prompt_text with EngineCoreRequest"
                )
            tasks = await self.get_supported_tasks()
            request = self.input_processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                supported_tasks=tasks,
            )
            prompt_text = get_prompt_text(prompt)
        else:
            request = prompt
            if request.request_id != request_id:
                logger.warning_once(
                    "AsyncLLM.add_request() was passed a request_id parameter that "
                    "does not match the EngineCoreRequest.request_id attribute. The "
                    "latter will be used, and the former will be ignored."
                )

        self.input_processor.assign_request_id(request)

        # We start the output_handler on the first call to add_request() so
        # we can call __init__ before the event loop, which enables us
        # to handle startup failure gracefully in the OpenAI server.
        self._run_output_handler()

        # Respect pause state before accepting new requests.
        async with self._pause_cond:
            await self._pause_cond.wait_for(lambda: not self._paused)

        collector = RequestOutputCollector(params.output_kind, request.request_id)

        # Switch to the request's own params (process_inputs() may have
        # cloned or updated them).
        params = request.params

        if pooling or params.n == 1:
            await self._add_request(request, prompt_text, None, 0, collector)
            return collector

        assert isinstance(params, SamplingParams)

        # Fan out child requests (for n>1). The last child reuses the
        # original request object; the others are shallow copies.
        parent = ParentRequest(request)
        for idx in range(params.n):
            child_id, child_params = parent.get_child_info(idx)
            is_last = idx == params.n - 1
            child = request if is_last else copy(request)
            child.request_id = child_id
            child.sampling_params = child_params
            await self._add_request(child, prompt_text, parent, idx, collector)
        return collector
407

408
409
410
    async def _add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        index: int,
        queue: RequestOutputCollector,
    ):
        """Register ``request`` locally, then hand it to the EngineCore.

        The request is first added to the OutputProcessor in this process
        (with its prompt text, optional parent, child index and output
        collector). It is then sent to the EngineCore, which runs in a
        separate process.
        """
        local_processor = self.output_processor
        local_processor.add_request(request, prompt, parent_req, index, queue)

        await self.engine_core.add_request_async(request)

        if not self.log_requests:
            return
        logger.info("Added request %s.", request.request_id)
424

425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
    async def _add_streaming_input_request(
        self,
        request_id: str,
        input_stream: AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> RequestOutputCollector:
        """Start a request whose prompt arrives incrementally from a stream.

        First, a placeholder request with a single dummy token is processed.
        It validates the parameters and provides the internal request id,
        and it is later submitted as the terminating request. A background
        task then consumes ``input_stream`` and submits each chunk as a
        resumable request under that internal id.

        If consuming or submitting a chunk raises, the exception is put on
        the returned collector wrapped in InputStreamError. The terminating
        request is sent unless the consuming task was cancelled.

        Returns:
            The collector that receives outputs for the whole session.
        """
        self._validate_streaming_input_sampling_params(sampling_params)

        shared_kwargs = {
            "arrival_time": arrival_time,
            "lora_request": lora_request,
            "tokenization_kwargs": tokenization_kwargs,
            "trace_headers": trace_headers,
            "priority": priority,
            "data_parallel_rank": data_parallel_rank,
        }

        if not sampling_params.skip_clone:
            sampling_params = sampling_params.clone()
            sampling_params.skip_clone = True

        # Create request for validation, also used as the finished signal
        # once the input stream is closed.
        final_req = self.input_processor.process_inputs(
            request_id=request_id,
            prompt=TokensPrompt(prompt_token_ids=[0]),
            params=sampling_params,
            **shared_kwargs,  # type: ignore[arg-type]
        )
        self.input_processor.assign_request_id(final_req)
        engine_req_id = final_req.request_id

        collector = RequestOutputCollector(sampling_params.output_kind, engine_req_id)

        async def pump_inputs():
            was_cancelled = False
            try:
                async for chunk in input_stream:
                    chunk_params = chunk.sampling_params
                    if not chunk_params:
                        chunk_params = sampling_params
                    else:
                        self._validate_streaming_input_sampling_params(chunk_params)
                    # TODO(nick): Avoid re-validating reused sampling parameters
                    chunk_req = self.input_processor.process_inputs(
                        request_id=engine_req_id,
                        prompt=chunk.prompt,
                        params=chunk_params,
                        resumable=True,
                        **shared_kwargs,  # type: ignore[arg-type]
                    )
                    chunk_req.external_req_id = request_id
                    if chunk_req.prompt_embeds is not None:
                        raise ValueError(
                            "prompt_embeds not supported for streaming inputs"
                        )
                    chunk_text = get_prompt_text(chunk.prompt)
                    await self._add_request(chunk_req, chunk_text, None, 0, collector)
            except (asyncio.CancelledError, GeneratorExit):
                was_cancelled = True
            except Exception as exc:
                # Wrap in InputStreamError so generate() can propagate it
                # without wrapping in EngineGenerateError.
                collector.put(InputStreamError(exc))
            finally:
                collector._input_stream_task = None
                if not was_cancelled:
                    # Empty final request signals the end of inputs; skipped
                    # on cancellation because the session was aborted.
                    await self._add_request(final_req, None, None, 0, collector)

        # Ensure output handler is running.
        self._run_output_handler()

        collector._input_stream_task = asyncio.create_task(pump_inputs())
        return collector

    @staticmethod
    def _validate_streaming_input_sampling_params(
        params: SamplingParams | PoolingParams,
    ):
        """Reject parameter combinations that input streaming cannot handle.

        Raises:
            ValueError: For non-SamplingParams (e.g. pooling), ``n > 1``,
                ``RequestOutputKind.FINAL_ONLY`` output, or any ``stop``
                strings.
        """
        if isinstance(params, SamplingParams):
            supported = not (
                params.n > 1
                or params.output_kind == RequestOutputKind.FINAL_ONLY
                or params.stop
            )
            if supported:
                return
        raise ValueError(
            "Input streaming not currently supported "
            "for pooling models, n > 1, request_kind = FINAL_ONLY "
            "or with stop strings."
        )

524
525
526
527
528
    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
529
    async def generate(
        self,
        prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Making an AsyncStream corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the Detokenizer.
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request AsyncStream.

        The caller of generate() iterates the returned AsyncGenerator,
        returning the RequestOutput back to the caller.

        Raises:
            asyncio.CancelledError / GeneratorExit: Re-raised after aborting
                the request (client disconnect or generator closed).
            EngineDeadError: The engine is dead; the request is not aborted.
            ValueError: Request validation failed.
            Exception: Whatever the streaming input generator raised, re-raised
                as-is (the request is aborted first).
            EngineGenerateError: Any other failure, chained from the cause.
        """
        q: RequestOutputCollector | None = None
        try:
            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                prompt_text=prompt_text,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                assert isinstance(out, RequestOutput)
                finished = out.finished
                # STREAM_FINISHED only terminates the loop; never yield it.
                if out is not STREAM_FINISHED:
                    yield out

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError as e:
            if self.log_requests:
                logger.info("Request %s failed (bad request): %s.", request_id, e)
            raise

        # Error from input stream generator - propagate directly.
        except InputStreamError as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed (input error): %s.", request_id, e)
            raise e.cause from e

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                # str(e) itself may raise; fall back to naming both classes.
                try:
                    s = f"{e.__class__.__name__}: {e}"
                except Exception as e2:
                    s = (
                        f"{e.__class__.__name__}: "
                        "error during printing an exception of class "
                        + e2.__class__.__name__
                    )
                logger.info("Request %s failed due to %s.", request_id, s)
            raise EngineGenerateError() from e

        finally:
            if q is not None:
                q.close()
634
635
636
637
638
639
640
641
642
643
644
645

    def _run_output_handler(self):
        """Background loop: pulls from EngineCore and pushes to AsyncStreams.

        Does nothing if the handler task already exists. Otherwise it
        schedules the loop with ``asyncio.create_task``, which requires a
        running event loop.
        """
        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly:
        # everything the loop needs is captured as a plain local.
        core = self.engine_core
        processor = self.output_processor
        stats_enabled = self.log_stats
        stat_manager = self.logger_manager
        inputs = self.input_processor
        chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    batch = await core.get_output_async()
                    core_outputs = batch.outputs
                    total = len(core_outputs)

                    if stats_enabled and total:
                        iteration_stats = IterationStats()
                    else:
                        iteration_stats = None

                    # Process at most VLLM_V1_OUTPUT_PROC_CHUNK_SIZE outputs
                    # at a time so the event loop is not blocked for too long.
                    for lo in range(0, total, chunk_size):
                        hi = lo + chunk_size

                        # 2) Process EngineCoreOutputs.
                        processed = processor.process_outputs(
                            core_outputs[lo:hi], batch.timestamp, iteration_stats
                        )
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if hi < total:
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        to_abort = processed.reqs_to_abort
                        if to_abort:
                            await core.abort_requests_async(to_abort)

                    processor.update_scheduler_stats(batch.scheduler_stats)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if stat_manager:
                        stat_manager.record(
                            engine_idx=batch.engine_index,
                            scheduler_stats=batch.scheduler_stats,
                            iteration_stats=iteration_stats,
                            mm_cache_stats=inputs.stat_mm_cache(),
                        )
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")
                processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())
702

703
704
705
    async def abort(
        self, request_id: str | Iterable[str], internal: bool = False
    ) -> None:
        """Abort RequestId in OutputProcessor and EngineCore.

        ``request_id`` may be a single id or an iterable of ids. The
        OutputProcessor returns the full set of ids to abort, and that set
        is then forwarded to the EngineCore.
        """
        if isinstance(request_id, str):
            ids = (request_id,)
        else:
            ids = as_list(request_id)

        expanded_ids = self.output_processor.abort_requests(ids, internal)
        await self.engine_core.abort_requests_async(expanded_ids)

        if not self.log_requests:
            return
        logger.info("Aborted request(s) %s.", ",".join(ids))
716

717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
    async def pause_generation(
        self,
        *,
        wait_for_inflight_requests: bool = False,
        clear_cache: bool = True,
    ) -> None:
        """
        Pause generation to allow model weight updates.

        New requests submitted through add_request() block until resume.
        Calling this while already paused is a no-op.

        Args:
            wait_for_inflight_requests: When ``True`` waits for in-flight
                requests to finish before pausing. When ``False`` (default),
                immediately aborts any in-flight requests.
            clear_cache: Whether to reset the prefix cache, multi-modal cache
                and encoder cache once in-flight requests have drained. Set to
                ``False`` to preserve cache for faster resume. Default is
                ``True`` (clear caches).
        """
        async with self._pause_cond:
            if self._paused:
                return
            self._paused = True

        if not wait_for_inflight_requests:
            inflight_ids = list(self.output_processor.request_states)
            if inflight_ids:
                await self.abort(inflight_ids, internal=True)

        # Wait for running requests to drain before clearing cache.
        if self.output_processor.has_unfinished_requests():
            await self.output_processor.wait_for_requests_to_drain()

        if not clear_cache:
            return
        await self.reset_prefix_cache()
        await self.reset_mm_cache()
        await self.reset_encoder_cache()
756
757
758
759
760
761
762
763
764
765
766
767
768
769

    async def resume_generation(self) -> None:
        """Undo :meth:`pause_generation` and let blocked requests continue."""
        cond = self._pause_cond
        async with cond:
            self._paused = False
            # Wake every coroutine waiting on the pause condition.
            cond.notify_all()

    async def is_paused(self) -> bool:
        """Report whether generation is currently paused."""
        async with self._pause_cond:
            paused = self._paused
        return paused

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Main function called by the API server to run a pooling request.

        The request is submitted with :meth:`add_request`, which returns the
        per-request output collector. A separate output_handler loop runs in
        a background asyncio task, pulling outputs from EngineCore and
        pushing them into that collector. This generator drains the collector
        and yields each :class:`PoolingRequestOutput` to the caller, stopping
        after the first output marked ``finished``.

        Args:
            prompt: The prompt to encode.
            pooling_params: Pooling parameters for the request.
            request_id: Id of the request; passed to add_request and used in
                log messages.
            lora_request: Optional LoRA request, forwarded to add_request.
            trace_headers: Optional tracing headers, forwarded to add_request.
            priority: Request priority, forwarded to add_request.
            tokenization_kwargs: Optional tokenization options, forwarded to
                add_request.

        Yields:
            PoolingRequestOutput objects for this request.

        Raises:
            asyncio.CancelledError: Re-raised after aborting the request
                (e.g. the client disconnected).
            EngineDeadError: Re-raised without aborting, since the engine is
                shutting down.
            ValueError: Re-raised without aborting (request validation error).
            EngineGenerateError: Raised from any other exception, after
                aborting the request.
        """
        q: RequestOutputCollector | None = None
        try:
            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the client disconnects, this encode() generator is cancelled,
        # so abort the request if we end up here.
        except asyncio.CancelledError:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

        finally:
            # Always release the per-request collector.
            if q is not None:
                q.close()

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """Tokenizer held by the input processor (may be ``None``)."""
        processor = self.input_processor
        return processor.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer from the input processor's ``get_tokenizer()``."""
        processor = self.input_processor
        return processor.get_tokenizer()

    @property
    def renderer(self) -> BaseRenderer:
        """Renderer owned by the input processor."""
        processor = self.input_processor
        return processor.renderer

    async def is_tracing_enabled(self) -> bool:
        """Return ``True`` when an OTLP traces endpoint is configured."""
        endpoint = self.observability_config.otlp_traces_endpoint  # type: ignore
        return endpoint is not None

    async def do_log_stats(self) -> None:
        """Ask the stat-logger manager, when one is set, to log its stats."""
        manager = self.logger_manager
        if not manager:
            return
        manager.log()

    async def check_health(self) -> None:
        """Raise ``self.dead_error`` if the engine is in an errored state."""
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error

    async def start_profile(self) -> None:
        """Start profiling in EngineCore and, if configured, the local profiler.

        The local profiler's ``start`` runs in a worker thread via
        ``asyncio.to_thread``. Both operations are awaited together.
        """
        pending = [self.engine_core.profile_async(True)]
        local_profiler = self.profiler
        if local_profiler is not None:
            pending.append(asyncio.to_thread(local_profiler.start))
        await asyncio.gather(*pending)

    async def stop_profile(self) -> None:
        """Stop profiling in EngineCore and, if configured, the local profiler.

        The local profiler's ``stop`` runs in a worker thread via
        ``asyncio.to_thread``. Both operations are awaited together.
        """
        pending = [self.engine_core.profile_async(False)]
        local_profiler = self.profiler
        if local_profiler is not None:
            pending.append(asyncio.to_thread(local_profiler.stop))
        await asyncio.gather(*pending)

    async def reset_mm_cache(self) -> None:
        """Clear the multimodal cache locally and in EngineCore."""
        # Input-processor side first, then the engine side.
        self.input_processor.clear_mm_cache()
        await self.engine_core.reset_mm_cache_async()

    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the prefix cache in EngineCore.

        Args:
            reset_running_requests: Passed positionally to
                ``engine_core.reset_prefix_cache_async``.
            reset_connector: Passed positionally to
                ``engine_core.reset_prefix_cache_async``.

        Returns:
            The value returned by EngineCore.
        """
        core = self.engine_core
        outcome = await core.reset_prefix_cache_async(
            reset_running_requests, reset_connector
        )
        return outcome

    async def reset_encoder_cache(self) -> None:
        """Ask EngineCore to reset its encoder cache."""
        core = self.engine_core
        await core.reset_encoder_cache_async()

    async def sleep(self, level: int = 1) -> None:
        """Put EngineCore to sleep at ``level``.

        The prefix cache is reset first. Afterwards the new sleep state is
        recorded with the stat-logger manager, if there is one.
        """
        await self.reset_prefix_cache()
        await self.engine_core.sleep_async(level)

        manager = self.logger_manager
        if manager is not None:
            # (1, level) here vs (0, 0) in wake_up; presumably 1 means
            # "asleep" -- confirm against record_sleep_state.
            manager.record_sleep_state(1, level)

    async def wake_up(self, tags: list[str] | None = None) -> None:
908
        await self.engine_core.wake_up_async(tags)
909

910
911
912
        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

913
914
915
    async def is_sleeping(self) -> bool:
        """Ask EngineCore whether it is currently sleeping."""
        sleeping = await self.engine_core.is_sleeping_async()
        return sleeping

    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests.

        Returns:
            The result reported by ``engine_core.add_lora_async``.
        """
        core = self.engine_core
        return await core.add_lora_async(lora_request)

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already-loaded LoRA adapter by its id."""
        core = self.engine_core
        return await core.remove_lora_async(lora_id)

    async def list_loras(self) -> set[int]:
        """Return the ids of all registered LoRA adapters."""
        core = self.engine_core
        return await core.list_loras_async()

    async def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter so that it is not evicted."""
        core = self.engine_core
        return await core.pin_lora_async(lora_id)

    async def collective_rpc(
        self,
        method: str,
935
        timeout: float | None = None,
936
        args: tuple = (),
937
        kwargs: dict | None = None,
938
    ):
939
940
941
942
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
943
944
            method, timeout, args, kwargs
        )
945

946
947
948
949
950
951
952
953
    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """Poll until no data-parallel engine is running requests.

        ``engine_core.dp_engines_running()`` is checked about once per second.

        Args:
            drain_timeout: Maximum number of seconds to wait.

        Raises:
            TimeoutError: If the engines are still running after
                ``drain_timeout`` seconds.
        """
        # Measure elapsed time with a monotonic clock so that wall-clock
        # adjustments (NTP steps, manual changes) cannot shorten or extend
        # the timeout.
        deadline = time.monotonic() + drain_timeout
        while time.monotonic() < deadline:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

            logger.info("Engines are still running, waiting for requests to drain...")
            await asyncio.sleep(1)  # Wait 1 second before checking again

        raise TimeoutError(
            f"Timeout reached after {drain_timeout} seconds "
            "waiting for requests to drain."
        )

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ):
        """
        Scale the data parallel size up or down by adding or removing
        engine cores.

        Waits for in-flight requests to drain, asks EngineCore to rescale,
        and then stores the new size in ``vllm_config.parallel_config``.
        When scaling up with stats logging enabled, the stat-logger manager
        is recreated for the new set of engines.

        Args:
            new_data_parallel_size: The new number of data parallel workers.
            drain_timeout: Maximum time to wait for requests to drain
                (seconds).

        Raises:
            TimeoutError: Propagated from :meth:`wait_for_requests_to_drain`
                if requests do not drain in time.
        """
        old_data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if old_data_parallel_size == new_data_parallel_size:
            logger.info(
                "Data parallel size is already %s, skipping scale",
                new_data_parallel_size,
            )
            return

        # This path handles both scale-up and scale-down; log the real one.
        direction = (
            "up" if new_data_parallel_size > old_data_parallel_size else "down"
        )
        logger.info(
            "Waiting for requests to drain before scaling %s to %s engines...",
            direction,
            new_data_parallel_size,
        )
        await self.wait_for_requests_to_drain(drain_timeout)
        logger.info(
            "Requests have been drained, proceeding with scale to %s engines",
            new_data_parallel_size,
        )
        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size

        # recreate stat loggers
        if new_data_parallel_size > old_data_parallel_size and self.log_stats:
            # TODO(rob): fix this after talking with Ray team.
            # This resets all the prometheus metrics since we
            # unregister during initialization. Need to understand
            # the intended behavior here better.
            self.logger_manager = StatLoggerManager(
                vllm_config=self.vllm_config,
                engine_idxs=list(range(new_data_parallel_size)),
                custom_stat_loggers=None,
            )

    @property
    def is_running(self) -> bool:
        """Whether the output-handler task is still alive."""
        handler = self.output_handler
        # The handler is None until the output loop has been started; that
        # state counts as running.
        if handler is None:
            return True
        return not handler.done()

    @property
    def is_stopped(self) -> bool:
        """Same as :attr:`errored`: the engine counts as stopped once errored."""
        stopped = self.errored
        return stopped

    @property
    def errored(self) -> bool:
        """True if EngineCore reports the engine dead or the loop is not running."""
        engine_dead = self.engine_core.resources.engine_dead
        return engine_dead or not self.is_running

    @property
    def dead_error(self) -> BaseException:
        """A new :class:`EngineDeadError`, as raised by ``check_health``."""
        error: BaseException = EngineDeadError()
        return error

    async def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest
    ) -> None:
        """
        Initialize weight transfer for RL training.

        Forwards ``request.init_info`` to every worker through
        :meth:`collective_rpc` as ``init_weight_transfer_engine(init_info=...)``.

        Args:
            request: Weight transfer initialization request with backend-specific info

        Raises:
            TypeError: If ``request`` is not a ``WeightTransferInitRequest``.
        """
        # WeightTransferInitRequest is already imported at module level, so no
        # function-local re-import is needed.
        if not isinstance(request, WeightTransferInitRequest):
            raise TypeError(f"Expected WeightTransferInitRequest, got {type(request)}")

        await self.collective_rpc(
            "init_weight_transfer_engine", kwargs={"init_info": request.init_info}
        )

    async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
        """
        Batched weight update for RL training.

        Args:
            request: Weight update request with backend-specific update info
        """

        if isinstance(request, WeightTransferUpdateRequest):
            update_info_dict = request.update_info
        else:
            raise TypeError(
                f"Expected WeightTransferUpdateRequest, got {type(request)}"
            )

        await self.collective_rpc(
            "update_weights", kwargs={"update_info": update_info_dict}
        )