# async_llm.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import os
import socket
6
import time
7
import warnings
8
from collections.abc import AsyncGenerator, Iterable, Mapping
9
from copy import copy
10
from typing import Any
11

12
import torch
13

14
import vllm.envs as envs
15
from vllm import TokensPrompt
16
from vllm.config import VllmConfig
17
18
19
20
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
21
from vllm.engine.arg_utils import AsyncEngineArgs
22
from vllm.engine.protocol import EngineClient
23
from vllm.inputs import PromptType, StreamingInput
24
25
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
26
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
27
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
28
from vllm.plugins.io_processors import get_io_processor
29
from vllm.pooling_params import PoolingParams
30
from vllm.renderers import BaseRenderer, merge_kwargs
31
from vllm.sampling_params import RequestOutputKind, SamplingParams
32
from vllm.tasks import SupportedTask
33
from vllm.tokenizers import TokenizerLike
34
from vllm.tracing import init_tracer
35
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
36
from vllm.usage.usage_lib import UsageContext
37
38
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
39
from vllm.v1.engine import EngineCoreRequest
40
from vllm.v1.engine.core_client import EngineCoreClient
41
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
42
from vllm.v1.engine.input_processor import InputProcessor
43
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
44
from vllm.v1.engine.parallel_sampling import ParentRequest
45
from vllm.v1.engine.utils import get_prompt_text
46
from vllm.v1.executor import Executor
47
48
49
50
51
from vllm.v1.metrics.loggers import (
    StatLoggerFactory,
    StatLoggerManager,
    load_stat_logger_plugin_factories,
)
52
from vllm.v1.metrics.prometheus import shutdown_prometheus
53
from vllm.v1.metrics.stats import IterationStats
54
55
56
57

# Module-level logger, named after this module.
logger = init_logger(__name__)


58
59
60
61
62
63
64
65
66
67
68
69
class InputStreamError(Exception):
    """Carries an exception raised by a user-supplied input stream.

    The streaming-input path puts an instance of this into the request's
    output queue; ``generate()`` catches it and re-raises ``cause`` directly,
    so errors originating in the caller's input generator are not wrapped in
    EngineGenerateError.

    Attributes:
        cause: the original exception raised by the input stream.
    """

    def __init__(self, cause: Exception):
        message = str(cause)
        super().__init__(message)
        # Keep the original exception so it can be re-raised unwrapped.
        self.cause = cause


70
71
72
73
class AsyncLLM(EngineClient):
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: list[StatLoggerFactory] | None = None,
        aggregate_engine_logging: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats. If custom stat loggers are
                present (passed in or loaded from plugins), stats are
                collected even when this is False; only the default
                loggers are then disabled.
            usage_context: Usage context of the LLM. Not referenced by
                this constructor.
            mm_registry: Multi-modal registry. Not referenced by this
                constructor.
            use_cached_outputs: Whether to use cached outputs. Not
                referenced by this constructor.
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop. Not
                referenced by this constructor: the output handler is
                started here only if an event loop is already running,
                otherwise on the first add_request().
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            aggregate_engine_logging: Forwarded to StatLoggerManager.
            client_addresses: Forwarded to the EngineCore client.
            client_count: Forwarded to the EngineCore client and to
                StatLoggerManager.
            client_index: Forwarded to the EngineCore client.

        Returns:
            None
        """
        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.log_requests = log_requests

        # Custom loggers are the explicitly passed ones plus any provided
        # by stat-logger plugins.
        custom_stat_loggers = list(stat_loggers or [])
        custom_stat_loggers.extend(load_stat_logger_plugin_factories())

        has_custom_loggers = bool(custom_stat_loggers)
        self.log_stats = log_stats or has_custom_loggers
        if not log_stats and has_custom_loggers:
            logger.info(
                "AsyncLLM created with log_stats=False, "
                "but custom stat loggers were found; "
                "enabling logging without default stat loggers."
            )

        self.input_processor = InputProcessor(self.vllm_config)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(
            self.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
        )
        endpoint = self.observability_config.otlp_traces_endpoint
        if endpoint is not None:
            tracer = init_tracer("vllm.llm_engine", endpoint)
            self.output_processor.tracer = tracer

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Loggers.
        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=custom_stat_loggers,
                # Default loggers only when the caller actually asked for
                # stats (not when enabled solely by custom loggers).
                enable_default_loggers=log_stats,
                client_count=client_count,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        # Pause / resume state for async RL workflows.
        self._pause_cond = asyncio.Condition()
        self._paused = False

        self.output_handler: asyncio.Task | None = None
        # Start output handler eagerly if we are in the asyncio eventloop.
        # Only get_running_loop() is guarded: a RuntimeError raised while
        # starting the handler itself must propagate, not be swallowed.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # No running loop yet; add_request() starts the handler.
        else:
            self._run_output_handler()

        if (
            vllm_config.profiler_config.profiler == "torch"
            and not vllm_config.profiler_config.ignore_frontend
        ):
            profiler_dir = vllm_config.profiler_config.torch_profiler_dir
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                profiler_dir,
            )
            worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=vllm_config.profiler_config.torch_profiler_with_stack,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    profiler_dir,
                    worker_name=worker_name,
                    use_gzip=vllm_config.profiler_config.torch_profiler_use_gzip,
                ),
            )
        else:
            self.profiler = None

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_log_requests: bool = False,
        aggregate_engine_logging: bool = False,
        disable_log_stats: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from an already-built VllmConfig.

        The executor class is resolved from ``vllm_config`` via
        ``Executor.get_class``. The flag names differ from ``__init__``:
        ``enable_log_requests`` maps to ``log_requests`` and
        ``disable_log_stats`` maps to ``log_stats=not disable_log_stats``.
        All other arguments are forwarded unchanged.
        """
        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            log_requests=enable_log_requests,
            log_stats=not disable_log_stats,
            aggregate_engine_logging=aggregate_engine_logging,
            usage_context=usage_context,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs.

        The engine config is built with
        ``engine_args.create_engine_config(usage_context)``. Request/stat
        logging flags are taken from ``engine_args.enable_log_requests``
        and ``engine_args.disable_log_stats``. Client-topology arguments
        (addresses/count/index) are not passed, so ``__init__`` defaults
        apply.
        """

        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    def __del__(self):
        # Release the background engine process and IPC resources when the
        # object is garbage collected; shutdown() tolerates a partially
        # initialized instance.
        self.shutdown()

    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC.

        Safe on a partially constructed instance: components are looked up
        with ``getattr(..., None)``, so attributes that ``__init__`` never
        assigned (e.g. because it raised early) are skipped. Also invoked
        from ``__del__``.
        """

        shutdown_prometheus()

        if engine_core := getattr(self, "engine_core", None):
            engine_core.shutdown()

        if input_processor := getattr(self, "input_processor", None):
            input_processor.close()

        handler = getattr(self, "output_handler", None)
        if handler is not None:
            # Cancellation helper that is safe to call from a thread other
            # than the one running the task's loop.
            cancel_task_threadsafe(handler)

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the engine core.

        The value is fetched from the engine core on the first call and
        cached on the instance as ``self._supported_tasks``; later calls
        return the cached tuple without contacting the engine core.
        """
        if not hasattr(self, "_supported_tasks"):
            # Cache the result
            self._supported_tasks = await self.engine_core.get_supported_tasks_async()

        return self._supported_tasks

    async def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        prompt_text: str | None = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        Args:
            request_id: caller-facing request id. Ignored (with a one-time
                warning) when ``prompt`` is already an EngineCoreRequest
                with a different id.
            prompt: a prompt, a pre-built EngineCoreRequest, or an async
                generator of StreamingInput chunks (streaming input).
            params: sampling or pooling parameters.
            arrival_time, lora_request, tokenization_kwargs, trace_headers,
            priority, data_parallel_rank: forwarded to input processing.
            prompt_text: prompt text to associate with a pre-built
                EngineCoreRequest. Not forwarded on the streaming path.

        Returns:
            The RequestOutputCollector the output handler pushes this
            request's outputs into (shared by all children when n > 1).

        Raises:
            EngineDeadError: if the engine has errored.
            ValueError: if kv_sharing_fast_prefill is combined with prompt
                logprobs, or if prompt_text is given without an
                EngineCoreRequest prompt.
        """

        if self.errored:
            raise EngineDeadError()

        is_pooling = isinstance(params, PoolingParams)

        if (
            self.vllm_config.cache_config.kv_sharing_fast_prefill
            and not is_pooling
            and params.prompt_logprobs
        ):
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs"
            )

        if params.truncate_prompt_tokens is not None:
            params_type = type(params).__name__
            warnings.warn(
                f"The `truncate_prompt_tokens` parameter in `{params_type}` "
                "is deprecated and will be removed in v0.16. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )

            # Carry the deprecated setting forward via tokenization_kwargs.
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=params.truncate_prompt_tokens),
            )

        if isinstance(prompt, AsyncGenerator):
            # Streaming input case.
            # NOTE(review): this path returns before the pause-state wait
            # below and does not forward prompt_text — confirm intended.
            return await self._add_streaming_input_request(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
                data_parallel_rank,
            )

        # Convert Input --> Request.
        if isinstance(prompt, EngineCoreRequest):
            request = prompt
            if request_id != request.request_id:
                logger.warning_once(
                    "AsyncLLM.add_request() was passed a request_id parameter that "
                    "does not match the EngineCoreRequest.request_id attribute. The "
                    "latter will be used, and the former will be ignored."
                )
        else:
            if prompt_text is not None:
                raise ValueError(
                    "should only provide prompt_text with EngineCoreRequest"
                )
            request = self.input_processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                supported_tasks=await self.get_supported_tasks(),
            )
            prompt_text = get_prompt_text(prompt)

        self.input_processor.assign_request_id(request)

        # We start the output_handler on the first call to add_request() so
        # we can call __init__ before the event loop, which enables us
        # to handle startup failure gracefully in the OpenAI server.
        self._run_output_handler()

        # Respect pause state before accepting new requests.
        async with self._pause_cond:
            await self._pause_cond.wait_for(lambda: not self._paused)

        # Create a new output collector for the request.
        queue = RequestOutputCollector(params.output_kind, request.request_id)

        # Use cloned params that may have been updated in process_inputs()
        params = request.params

        if is_pooling or params.n == 1:
            await self._add_request(request, prompt_text, None, 0, queue)
            return queue

        parent_params = params
        assert isinstance(parent_params, SamplingParams)

        # Fan out child requests (for n>1).
        parent_request = ParentRequest(request)
        for idx in range(parent_params.n):
            request_id, child_params = parent_request.get_child_info(idx)
            # The last child reuses the original request object instead of
            # a shallow copy.
            child_request = request if idx == parent_params.n - 1 else copy(request)
            child_request.request_id = request_id
            child_request.sampling_params = child_params
            await self._add_request(
                child_request, prompt_text, parent_request, idx, queue
            )
        return queue

    async def _add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        index: int,
        queue: RequestOutputCollector,
    ):
        """Register one EngineCoreRequest locally and with the EngineCore.

        Args:
            request: the request to submit.
            prompt: prompt text passed to ``OutputProcessor.add_request``
                (may be None).
            parent_req: the ParentRequest when this is a child of an n > 1
                fan-out, else None.
            index: child index within the fan-out (0 for single requests).
            queue: collector that receives this request's outputs.
        """
        # Add the request to OutputProcessor (this process) first —
        # presumably so that outputs coming back from EngineCore always find
        # a registered request state.
        self.output_processor.add_request(request, prompt, parent_req, index, queue)

        # Add the EngineCoreRequest to EngineCore (separate process).
        await self.engine_core.add_request_async(request)

        if self.log_requests:
            logger.info("Added request %s.", request.request_id)

    async def _add_streaming_input_request(
        self,
        request_id: str,
        input_stream: AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> RequestOutputCollector:
        """Start a request whose prompt arrives incrementally from a stream.

        Each StreamingInput chunk is processed into a resumable
        EngineCoreRequest under a single internal request id and submitted
        as it arrives, from a background task. When the stream ends (or
        raises), a placeholder ``final_req`` is submitted to signal that no
        more input is coming; it is skipped if the task was cancelled.
        Errors raised while consuming the stream are delivered through the
        returned queue wrapped in InputStreamError.

        Returns:
            The RequestOutputCollector for the request; the background task
            handle is stored on ``queue._input_stream_task`` while running.

        Raises:
            ValueError: if the sampling params are unsupported for streaming
                input (see _validate_streaming_input_sampling_params).
        """
        self._validate_streaming_input_sampling_params(sampling_params)

        inputs = dict(
            arrival_time=arrival_time,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            trace_headers=trace_headers,
            priority=priority,
            data_parallel_rank=data_parallel_rank,
        )

        # Clone once up front so every chunk that reuses these params does
        # not clone again.
        if not sampling_params.skip_clone:
            sampling_params = sampling_params.clone()
            sampling_params.skip_clone = True

        # Create request for validation, also used as the finished signal
        # once the input stream is closed.
        final_req = self.input_processor.process_inputs(
            request_id=request_id,
            prompt=TokensPrompt(prompt_token_ids=[0]),
            params=sampling_params,
            **inputs,  # type: ignore[arg-type]
        )
        self.input_processor.assign_request_id(final_req)
        internal_req_id = final_req.request_id

        queue = RequestOutputCollector(sampling_params.output_kind, internal_req_id)

        async def handle_inputs():
            cancelled = False
            try:
                async for input_chunk in input_stream:
                    sp = input_chunk.sampling_params
                    if sp:
                        self._validate_streaming_input_sampling_params(sp)
                    else:
                        sp = sampling_params
                    # TODO(nick): Avoid re-validating reused sampling parameters
                    req = self.input_processor.process_inputs(
                        request_id=internal_req_id,
                        prompt=input_chunk.prompt,
                        params=sp,
                        resumable=True,
                        **inputs,  # type: ignore[arg-type]
                    )
                    req.external_req_id = request_id
                    if req.prompt_embeds is not None:
                        raise ValueError(
                            "prompt_embeds not supported for streaming inputs"
                        )
                    prompt_text = get_prompt_text(input_chunk.prompt)
                    await self._add_request(req, prompt_text, None, 0, queue)
            except (asyncio.CancelledError, GeneratorExit):
                cancelled = True
            except Exception as error:
                # Wrap in InputStreamError so generate() can propagate it
                # without wrapping in EngineGenerateError.
                queue.put(InputStreamError(error))
            finally:
                queue._input_stream_task = None
                if not cancelled:
                    # Send empty final request to indicate that inputs have
                    # finished. Don't send if cancelled (session was aborted).
                    await self._add_request(final_req, None, None, 0, queue)

        # Ensure output handler is running.
        self._run_output_handler()

        queue._input_stream_task = asyncio.create_task(handle_inputs())
        return queue

    @staticmethod
    def _validate_streaming_input_sampling_params(
        params: SamplingParams | PoolingParams,
    ):
        """Reject parameter combinations unsupported with streaming input.

        Raises:
            ValueError: for pooling params, ``n > 1``,
                ``output_kind == RequestOutputKind.FINAL_ONLY``, or when
                stop strings are set.
        """
        if (
            not isinstance(params, SamplingParams)
            or params.n > 1
            or params.output_kind == RequestOutputKind.FINAL_ONLY
            or params.stop
        ):
            raise ValueError(
                "Input streaming not currently supported "
                "for pooling models, n > 1, output_kind = FINAL_ONLY "
                "or with stop strings."
            )

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Creating the per-request RequestOutputCollector.
            * 2) Processing the Input.
            * 3) Adding the Request to the OutputProcessor.
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request RequestOutputCollector.

        The caller of generate() iterates the returned AsyncGenerator,
        receiving each RequestOutput until one is marked finished. The
        STREAM_FINISHED sentinel ends the stream without being yielded.
        ``prompt`` may also be an async generator of StreamingInput chunks.

        Raises:
            asyncio.CancelledError, GeneratorExit: re-raised after aborting
                the request (e.g. client disconnect).
            EngineDeadError: the engine is dead; the request is not aborted.
            ValueError: request validation failed.
            Exception: an error raised by a streaming input generator is
                re-raised directly (unwrapped) after aborting the request.
            EngineGenerateError: any other unexpected error, chained from
                the original exception, after aborting the request.
        """

        q: RequestOutputCollector | None = None
        try:
            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                prompt_text=prompt_text,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                assert isinstance(out, RequestOutput)
                finished = out.finished
                if out is not STREAM_FINISHED:
                    yield out

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError as e:
            if self.log_requests:
                logger.info("Request %s failed (bad request): %s.", request_id, e)
            raise

        # Error from input stream generator - propagate directly.
        except InputStreamError as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed (input error): %s.", request_id, e)
            raise e.cause from e

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                # Formatting the exception may itself fail (e.g. a broken
                # __str__); fall back to naming both exception classes.
                try:
                    s = f"{e.__class__.__name__}: {e}"
                except Exception as e2:
                    s = (
                        f"{e.__class__.__name__}: "
                        "error during printing an exception of class "
                        + e2.__class__.__name__
                    )
                logger.info("Request %s failed due to %s.", request_id, s)
            raise EngineGenerateError() from e
        finally:
            if q is not None:
                q.close()

    def _run_output_handler(self):
        """Start (once) the background loop that pulls EngineCore outputs.

        The loop processes each batch of EngineCoreOutputs through the
        OutputProcessor, which pushes RequestOutputs into the per-request
        RequestOutputCollector queues. It also aborts requests the
        OutputProcessor flags (e.g. stop strings) and records stats. On
        any exception the error is propagated to pending requests via
        ``output_processor.propagate_error`` and the loop exits.
        Subsequent calls are no-ops once the task exists.
        """

        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        logger_manager = self.logger_manager
        input_processor = self.input_processor
        chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    num_outputs = len(outputs.outputs)

                    iteration_stats = (
                        IterationStats() if (log_stats and num_outputs) else None
                    )

                    # Split outputs into chunks of at most
                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                    # event loop for too long.
                    engine_core_outputs = outputs.outputs
                    for start in range(0, num_outputs, chunk_size):
                        end = start + chunk_size
                        outputs_slice = engine_core_outputs[start:end]
                        # 2) Process EngineCoreOutputs.
                        processed_outputs = output_processor.process_outputs(
                            outputs_slice, outputs.timestamp, iteration_stats
                        )
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed_outputs.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if end < num_outputs:
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        if processed_outputs.reqs_to_abort:
                            await engine_core.abort_requests_async(
                                processed_outputs.reqs_to_abort
                            )

                    output_processor.update_scheduler_stats(outputs.scheduler_stats)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if logger_manager:
                        logger_manager.record(
                            engine_idx=outputs.engine_index,
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                            mm_cache_stats=input_processor.stat_mm_cache(),
                        )
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())

    async def abort(
        self, request_id: str | Iterable[str], internal: bool = False
    ) -> None:
        """Abort RequestId in OutputProcessor and EngineCore.

        Args:
            request_id: a single request id or an iterable of ids.
            internal: forwarded to ``OutputProcessor.abort_requests``. Within
                this class it is passed as True for ids taken from
                ``RequestOutputCollector.request_id`` or
                ``output_processor.request_states``.
        """
        request_ids = (
            (request_id,) if isinstance(request_id, str) else as_list(request_id)
        )
        # The OutputProcessor returns the ids to abort in the EngineCore
        # (presumably expanded to include fanned-out child requests).
        all_request_ids = self.output_processor.abort_requests(request_ids, internal)
        await self.engine_core.abort_requests_async(all_request_ids)

        if self.log_requests:
            logger.info("Aborted request(s) %s.", ",".join(request_ids))

    async def pause_generation(
        self,
        *,
        wait_for_inflight_requests: bool = False,
        clear_cache: bool = True,
    ) -> None:
        """
        Pause generation to allow model weight updates.

        New generation/encoding requests are blocked until resume. Calling
        this while already paused returns immediately without aborting,
        draining, or clearing caches.

        Args:
            wait_for_inflight_requests: When ``True`` waits for in-flight
                requests to finish before pausing. When ``False`` (default),
                immediately aborts any in-flight requests.
            clear_cache: Whether to clear KV cache and prefix cache after
                draining. Set to ``False`` to preserve cache for faster resume.
                Default is ``True`` (clear caches).
        """

        async with self._pause_cond:
            if self._paused:
                return
            self._paused = True

        if not wait_for_inflight_requests:
            # Snapshot the ids: aborting mutates request_states.
            request_ids = list(self.output_processor.request_states.keys())
            if request_ids:
                await self.abort(request_ids, internal=True)

        # Wait for running requests to drain before clearing cache.
        if self.output_processor.has_unfinished_requests():
            await self.output_processor.wait_for_requests_to_drain()

        # Clear cache
        if clear_cache:
            await self.reset_prefix_cache()
            await self.reset_mm_cache()
            await self.reset_encoder_cache()

    async def resume_generation(self) -> None:
        """Clear the paused flag set by :meth:`pause_generation`.

        Every coroutine waiting on the pause condition is notified, so that
        blocked requests can proceed.
        """
        cond = self._pause_cond
        async with cond:
            self._paused = False
            # Wake all waiters, not just one.
            cond.notify_all()

    async def is_paused(self) -> bool:
        """Report whether :meth:`pause_generation` is currently in effect."""
        # Read the flag under the same lock that writers hold.
        async with self._pause_cond:
            paused = self._paused
        return paused

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Main function called by the API server to kick off a pooling request.

        ``add_request`` processes the input, adds the request to the
        EngineCore (separate process), and returns the per-request
        ``RequestOutputCollector``.

        A separate output_handler loop runs in a background asyncio task,
        pulling outputs from EngineCore and putting them into that
        collector.

        The caller of encode() iterates the returned AsyncGenerator, which
        yields ``PoolingRequestOutput`` objects up to and including the first
        one whose ``finished`` flag is set.

        Raises:
            asyncio.CancelledError: The caller cancelled the generator (e.g.
                the client disconnected). If the request had already been
                added, it is aborted before re-raising.
            EngineDeadError: The engine is dead; no abort is attempted.
            ValueError: The request failed validation.
            EngineGenerateError: Any other error, chained from the original
                exception. If the request had already been added, it is
                aborted first.
        """

        # Stays None until add_request succeeds; the handlers below use this
        # to decide whether there is anything to abort or close.
        q: RequestOutputCollector | None = None

        try:
            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, encode()
        # is cancelled. So, we abort the request if we end up here.
        except asyncio.CancelledError:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

        finally:
            if q is not None:
                q.close()

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """Tokenizer exposed by the input processor (may be ``None``)."""
        processor = self.input_processor
        return processor.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer by delegating to the input processor."""
        processor = self.input_processor
        return processor.get_tokenizer()

    @property
    def renderer(self) -> BaseRenderer:
        """Renderer owned by the input processor."""
        processor = self.input_processor
        return processor.renderer

    async def is_tracing_enabled(self) -> bool:
        """Return ``True`` when an OTLP traces endpoint is configured."""
        endpoint = self.observability_config.otlp_traces_endpoint  # type: ignore
        return endpoint is not None

    async def do_log_stats(self) -> None:
        """Ask the stat logger manager, when one is set, to emit its logs."""
        manager = self.logger_manager
        if not manager:
            return
        manager.log()

    async def check_health(self) -> None:
        """Raise :attr:`dead_error` if the engine is in an errored state."""
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error

    async def start_profile(self) -> None:
        """Start profiling in the engine core and in the local profiler.

        The local profiler is optional. When present, its blocking ``start``
        runs in a worker thread, concurrently with the engine core call.
        """
        pending = [self.engine_core.profile_async(True)]
        local = self.profiler
        if local is not None:
            pending.append(asyncio.to_thread(local.start))
        await asyncio.gather(*pending)

    async def stop_profile(self) -> None:
        """Stop profiling in the engine core and in the local profiler.

        The local profiler is optional. When present, its blocking ``stop``
        runs in a worker thread, concurrently with the engine core call.
        """
        pending = [self.engine_core.profile_async(False)]
        local = self.profiler
        if local is not None:
            pending.append(asyncio.to_thread(local.stop))
        await asyncio.gather(*pending)

    async def reset_mm_cache(self) -> None:
        """Clear the multimodal cache locally, then in the engine core."""
        self.input_processor.clear_mm_cache()
        core = self.engine_core
        await core.reset_mm_cache_async()

    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the engine core's prefix cache.

        Args:
            reset_running_requests: Forwarded unchanged to the engine core.
            reset_connector: Forwarded unchanged to the engine core.

        Returns:
            The boolean result reported by the engine core.
        """
        core = self.engine_core
        result = await core.reset_prefix_cache_async(
            reset_running_requests, reset_connector
        )
        return result

    async def reset_encoder_cache(self) -> None:
        """Ask the engine core to reset its encoder cache."""
        core = self.engine_core
        await core.reset_encoder_cache_async()

    async def sleep(self, level: int = 1) -> None:
        """Put the engine to sleep at ``level``.

        The prefix cache is reset first. Afterwards the sleep state
        ``(1, level)`` is recorded on the stat logger manager, if there is one.
        """
        await self.reset_prefix_cache()
        await self.engine_core.sleep_async(level)

        manager = self.logger_manager
        if manager is None:
            return
        manager.record_sleep_state(1, level)

    async def wake_up(self, tags: list[str] | None = None) -> None:
906
        await self.engine_core.wake_up_async(tags)
907

908
909
910
        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

911
912
913
    async def is_sleeping(self) -> bool:
        """Ask the engine core whether it is currently asleep."""
        sleeping = await self.engine_core.is_sleeping_async()
        return sleeping

    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests.

        Returns the boolean result reported by the engine core.
        """
        loaded = await self.engine_core.add_lora_async(lora_request)
        return loaded

    async def remove_lora(self, lora_id: int) -> bool:
        """Unload the LoRA adapter ``lora_id`` from the engine.

        Returns the boolean result reported by the engine core.
        """
        removed = await self.engine_core.remove_lora_async(lora_id)
        return removed

    async def list_loras(self) -> set[int]:
        """Return the ids of all adapters registered with the engine core."""
        registered = await self.engine_core.list_loras_async()
        return registered

    async def pin_lora(self, lora_id: int) -> bool:
        """Pin adapter ``lora_id`` so that it is not evicted.

        Returns the boolean result reported by the engine core.
        """
        core = self.engine_core
        return await core.pin_lora_async(lora_id)

    async def collective_rpc(
        self,
        method: str,
933
        timeout: float | None = None,
934
        args: tuple = (),
935
        kwargs: dict | None = None,
936
    ):
937
938
939
940
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
941
942
            method, timeout, args, kwargs
        )
943

944
945
946
947
948
949
950
951
    async def wait_for_requests_to_drain(self, drain_timeout: float = 300):
        """Poll until no data-parallel engine is running, or time out.

        ``engine_core.dp_engines_running()`` is checked about once per second.

        Args:
            drain_timeout: Maximum number of seconds to wait.

        Raises:
            TimeoutError: If engines are still running after ``drain_timeout``
                seconds.
        """
        # Measure elapsed time on a monotonic clock, so that wall-clock
        # adjustments (NTP steps, manual changes) cannot shorten or extend
        # the timeout.
        deadline = time.monotonic() + drain_timeout
        while time.monotonic() < deadline:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

            logger.info("Engines are still running, waiting for requests to drain...")
            await asyncio.sleep(1)  # Wait 1 second before checking again

        raise TimeoutError(
            f"Timeout reached after {drain_timeout} seconds "
            "waiting for requests to drain."
        )

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ):
        """
        Scale up or down the data parallel size by adding or removing
        engine cores.

        Waits for in-flight requests to drain, asks the engine core to
        rescale, then stores the new size in ``vllm_config``. Does nothing
        beyond logging if the size is unchanged.

        Args:
            new_data_parallel_size: The new number of data parallel workers.
            drain_timeout: Maximum time to wait for requests to drain
                (seconds).

        Raises:
            TimeoutError: Presumably propagated from
                :meth:`wait_for_requests_to_drain` if requests do not drain
                within ``drain_timeout``.
        """
        old_data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if old_data_parallel_size == new_data_parallel_size:
            logger.info(
                "Data parallel size is already %s, skipping scale",
                new_data_parallel_size,
            )
            return
        # Direction-neutral wording: the new size may be smaller than the
        # current one.
        logger.info(
            "Waiting for requests to drain before scaling to %s engines...",
            new_data_parallel_size,
        )
        await self.wait_for_requests_to_drain(drain_timeout)
        logger.info(
            "Requests have been drained, proceeding with scale to %s engines",
            new_data_parallel_size,
        )
        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size

        # recreate stat loggers
        # NOTE(review): this only happens when scaling up; on scale-down the
        # existing logger manager is kept as-is. Confirm that is intended.
        if new_data_parallel_size > old_data_parallel_size and self.log_stats:
            # TODO(rob): fix this after talking with Ray team.
            # This resets all the prometheus metrics since we
            # unregister during initialization. Need to understand
            # the intended behavior here better.
            self.logger_manager = StatLoggerManager(
                vllm_config=self.vllm_config,
                engine_idxs=list(range(new_data_parallel_size)),
                custom_stat_loggers=None,
            )

    @property
    def is_running(self) -> bool:
        """Whether the output handler task is still alive.

        ``output_handler`` is ``None`` before the loop is started, and that
        state counts as running.
        """
        task = self.output_handler
        if task is None:
            return True
        return not task.done()

    @property
    def is_stopped(self) -> bool:
        """Same value as :attr:`errored`."""
        stopped = self.errored
        return stopped

    @property
    def errored(self) -> bool:
        """True when the engine core is dead or the output handler stopped."""
        dead = self.engine_core.resources.engine_dead
        return dead or not self.is_running

    @property
    def dead_error(self) -> BaseException:
        """A new :class:`EngineDeadError` instance, built on each access."""
        error = EngineDeadError()
        return error

    async def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest
    ) -> None:
        """
        Initialize weight transfer for RL training.

        Forwards ``request.init_info`` to the workers through
        :meth:`collective_rpc` (method ``"init_weight_transfer_engine"``).

        Args:
            request: Weight transfer initialization request with
                backend-specific info.

        Raises:
            TypeError: If ``request`` is not a ``WeightTransferInitRequest``.
        """
        # WeightTransferInitRequest comes from the module-level import.
        if not isinstance(request, WeightTransferInitRequest):
            raise TypeError(f"Expected WeightTransferInitRequest, got {type(request)}")

        await self.collective_rpc(
            "init_weight_transfer_engine", kwargs={"init_info": request.init_info}
        )

    async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
        """
        Apply a batched weight update for RL training.

        Forwards ``request.update_info`` to the workers through
        :meth:`collective_rpc` (method ``"update_weights"``).

        Args:
            request: Weight update request carrying backend-specific update info.

        Raises:
            TypeError: If ``request`` is not a ``WeightTransferUpdateRequest``.
        """
        if not isinstance(request, WeightTransferUpdateRequest):
            raise TypeError(
                f"Expected WeightTransferUpdateRequest, got {type(request)}"
            )
        update_info_dict = request.update_info

        payload = {"update_info": update_info_dict}
        await self.collective_rpc("update_weights", kwargs=payload)