async_llm.py 38.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import os
import socket
6
import time
7
import warnings
8
from collections.abc import AsyncGenerator, Iterable, Mapping
9
from copy import copy
10
from typing import Any
11

12
import torch
13

14
import vllm.envs as envs
15
from vllm import TokensPrompt
16
from vllm.config import VllmConfig
17
from vllm.engine.arg_utils import AsyncEngineArgs
18
from vllm.engine.protocol import EngineClient
19
from vllm.inputs import PromptType
20
from vllm.inputs.data import StreamingInput
21
22
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
23
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
24
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
25
from vllm.plugins.io_processors import get_io_processor
26
from vllm.pooling_params import PoolingParams
27
from vllm.renderers import RendererLike, merge_kwargs
28
from vllm.sampling_params import RequestOutputKind, SamplingParams
29
from vllm.tasks import SupportedTask
30
from vllm.tokenizers import TokenizerLike
31
from vllm.tracing import init_tracer
32
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
33
from vllm.usage.usage_lib import UsageContext
34
35
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
36
from vllm.v1.engine import EngineCoreRequest
37
from vllm.v1.engine.core_client import EngineCoreClient
38
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
39
from vllm.v1.engine.input_processor import InputProcessor
40
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
41
from vllm.v1.engine.parallel_sampling import ParentRequest
42
from vllm.v1.engine.utils import get_prompt_text
43
from vllm.v1.executor import Executor
44
45
46
47
48
from vllm.v1.metrics.loggers import (
    StatLoggerFactory,
    StatLoggerManager,
    load_stat_logger_plugin_factories,
)
49
from vllm.v1.metrics.prometheus import shutdown_prometheus
50
from vllm.v1.metrics.stats import IterationStats
51
52
53
54

# Logger for this module, named after the module path.
logger = init_logger(__name__)


class InputStreamError(Exception):
    """Carrier for an exception raised by a caller-supplied input stream.

    ``generate()`` unwraps this and re-raises ``cause`` directly, so errors
    coming from the caller's own generator are not reported as
    EngineGenerateError.

    Attributes:
        cause: the original exception raised by the input stream.
    """

    def __init__(self, cause: Exception):
        super().__init__(str(cause))
        self.cause = cause


class AsyncLLM(EngineClient):
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: list[StatLoggerFactory] | None = None,
        aggregate_engine_logging: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats. Stats collection is also enabled
                (without the default stat loggers) when custom or plugin stat
                logger factories are found.
            usage_context: Usage context of the LLM.
            mm_registry: Multi-modal registry.
            use_cached_outputs: Whether to use cached outputs.
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop.
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            aggregate_engine_logging: forwarded to StatLoggerManager.
            client_addresses: forwarded to the EngineCore client.
            client_count: forwarded to the EngineCore client and to
                StatLoggerManager.
            client_index: forwarded to the EngineCore client.

        Note:
            ``usage_context``, ``mm_registry``, ``use_cached_outputs`` and
            ``start_engine_loop`` are accepted for interface compatibility but
            are not read by this constructor.

        Returns:
            None
        """
        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.log_requests = log_requests

        # Custom loggers = user-supplied factories + plugin-provided ones.
        custom_stat_loggers = list(stat_loggers or [])
        custom_stat_loggers.extend(load_stat_logger_plugin_factories())

        has_custom_loggers = bool(custom_stat_loggers)
        self.log_stats = log_stats or has_custom_loggers
        if not log_stats and has_custom_loggers:
            logger.info(
                "AsyncLLM created with log_stats=False, "
                "but custom stat loggers were found; "
                "enabling logging without default stat loggers."
            )

        self.input_processor = InputProcessor(self.vllm_config)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(
            self.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
        )
        endpoint = self.observability_config.otlp_traces_endpoint
        if endpoint is not None:
            tracer = init_tracer("vllm.llm_engine", endpoint)
            self.output_processor.tracer = tracer

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Loggers.
        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=custom_stat_loggers,
                # Default loggers only when the caller asked for stats; custom
                # loggers alone must not switch the defaults on.
                enable_default_loggers=log_stats,
                client_count=client_count,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        # Pause / resume state for async RL workflows.
        self._pause_cond = asyncio.Condition()
        self._paused = False

        self.output_handler: asyncio.Task | None = None
        try:
            # Only probe for a running loop inside the try: an exception
            # raised by _run_output_handler() itself must propagate rather
            # than be mistaken for "no running event loop".
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop yet: add_request() starts the handler lazily.
            pass
        else:
            # Start output handler eagerly if we are in the asyncio eventloop.
            self._run_output_handler()

        if (
            vllm_config.profiler_config.profiler == "torch"
            and not vllm_config.profiler_config.ignore_frontend
        ):
            profiler_dir = vllm_config.profiler_config.torch_profiler_dir
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                profiler_dir,
            )
            # One trace file per host/process so frontends don't collide.
            worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=vllm_config.profiler_config.torch_profiler_with_stack,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    profiler_dir,
                    worker_name=worker_name,
                    use_gzip=vllm_config.profiler_config.torch_profiler_use_gzip,
                ),
            )
        else:
            self.profiler = None

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_log_requests: bool = False,
        aggregate_engine_logging: bool = False,
        disable_log_stats: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from an already-built VllmConfig.

        The executor class is resolved with ``Executor.get_class``. Note the
        flag translation: ``enable_log_requests`` becomes ``log_requests`` and
        ``disable_log_stats`` becomes ``log_stats=not disable_log_stats``.
        All other arguments are forwarded unchanged to the constructor.
        """
        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            log_requests=enable_log_requests,
            log_stats=not disable_log_stats,
            aggregate_engine_logging=aggregate_engine_logging,
            usage_context=usage_context,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs.

        Builds the VllmConfig via ``engine_args.create_engine_config`` and
        takes ``log_requests`` / ``log_stats`` from
        ``engine_args.enable_log_requests`` and
        ``not engine_args.disable_log_stats`` respectively.
        """

        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    def __del__(self):
        # Best-effort cleanup on garbage collection. shutdown() looks up its
        # attributes with getattr(), so this is safe even if __init__ raised
        # before they were assigned.
        self.shutdown()

    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC.

        Shuts down Prometheus, then the EngineCore client and the input
        processor, and finally cancels the output handler task. Attributes are
        fetched with ``getattr`` so this works on a partially-constructed
        instance (it is also reached from ``__del__``).
        """

        shutdown_prometheus()

        if engine_core := getattr(self, "engine_core", None):
            engine_core.shutdown()

        if input_processor := getattr(self, "input_processor", None):
            input_processor.close()

        handler = getattr(self, "output_handler", None)
        if handler is not None:
            # Presumably thread-safe cancellation because __del__ is not
            # guaranteed to run on the event-loop thread.
            cancel_task_threadsafe(handler)

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the engine, as reported by EngineCore."""
        return await self.engine_core.get_supported_tasks_async()

    async def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        prompt_text: str | None = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        Args:
            request_id: caller-visible request id. Ignored (with a warning)
                when ``prompt`` is an EngineCoreRequest carrying its own id.
            prompt: a prompt to process, a pre-built EngineCoreRequest, or an
                async generator of StreamingInput chunks (streaming input).
            params: sampling or pooling parameters.
            arrival_time, lora_request, tokenization_kwargs, trace_headers,
                priority, data_parallel_rank: forwarded to input processing.
            prompt_text: prompt text to associate with an EngineCoreRequest;
                for every other prompt kind it is derived from the prompt.

        Returns:
            The RequestOutputCollector that receives this request's outputs.
            For ``n > 1`` all child requests share this one collector.

        Raises:
            EngineDeadError: if the engine has already errored.
            ValueError: for unsupported parameter combinations, including
                ``prompt_text`` given with anything but an EngineCoreRequest.
        """

        if self.errored:
            raise EngineDeadError()

        is_pooling = isinstance(params, PoolingParams)

        if (
            self.vllm_config.cache_config.kv_sharing_fast_prefill
            and not is_pooling
            and params.prompt_logprobs
        ):
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs"
            )

        if params.truncate_prompt_tokens is not None:
            params_type = type(params).__name__
            warnings.warn(
                f"The `truncate_prompt_tokens` parameter in `{params_type}` "
                "is deprecated and will be removed in v0.16. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )

            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=params.truncate_prompt_tokens),
            )

        if isinstance(prompt, AsyncGenerator):
            # Streaming input case.
            # prompt_text is only meaningful for a pre-built EngineCoreRequest;
            # reject it here as the non-streaming path does, instead of
            # silently dropping it.
            if prompt_text is not None:
                raise ValueError(
                    "should only provide prompt_text with EngineCoreRequest"
                )
            # Respect pause state before opening a new streaming session, the
            # same as for regular requests below.
            async with self._pause_cond:
                await self._pause_cond.wait_for(lambda: not self._paused)
            return await self._add_streaming_input_request(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
                data_parallel_rank,
            )

        # Convert Input --> Request.
        if isinstance(prompt, EngineCoreRequest):
            request = prompt
            if request_id != request.request_id:
                logger.warning_once(
                    "AsyncLLM.add_request() was passed a request_id parameter that "
                    "does not match the EngineCoreRequest.request_id attribute. The "
                    "latter will be used, and the former will be ignored."
                )
        else:
            if prompt_text is not None:
                raise ValueError(
                    "should only provide prompt_text with EngineCoreRequest"
                )
            request = self.input_processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
            )
            prompt_text = get_prompt_text(prompt)

        self.input_processor.assign_request_id(request)

        # We start the output_handler on the first call to add_request() so
        # we can call __init__ before the event loop, which enables us
        # to handle startup failure gracefully in the OpenAI server.
        self._run_output_handler()

        # Respect pause state before accepting new requests.
        async with self._pause_cond:
            await self._pause_cond.wait_for(lambda: not self._paused)

        # Create a new output collector for the request.
        queue = RequestOutputCollector(params.output_kind, request.request_id)

        # Use cloned params that may have been updated in process_inputs()
        params = request.params

        if is_pooling or params.n == 1:
            await self._add_request(request, prompt_text, None, 0, queue)
            return queue

        parent_params = params
        assert isinstance(parent_params, SamplingParams)

        # Fan out child requests (for n>1).
        parent_request = ParentRequest(request)
        for idx in range(parent_params.n):
            request_id, child_params = parent_request.get_child_info(idx)
            # Copy for all but the last child; the last reuses the original
            # object, which is no longer needed as a copy source by then.
            child_request = request if idx == parent_params.n - 1 else copy(request)
            child_request.request_id = request_id
            child_request.sampling_params = child_params
            await self._add_request(
                child_request, prompt_text, parent_request, idx, queue
            )
        return queue

    async def _add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        index: int,
        queue: RequestOutputCollector,
    ):
        """Register ``request`` with the OutputProcessor and submit it to EngineCore.

        Args:
            request: the processed request to submit.
            prompt: prompt text recorded with the request in the
                OutputProcessor (may be None).
            parent_req: the ParentRequest when this is one of the ``n > 1``
                fan-out children, otherwise None.
            index: child index within ``parent_req`` (0 when there is none).
            queue: collector that receives this request's outputs.
        """
        # Add the request to OutputProcessor (this process). Done before the
        # engine submission, presumably so local state exists before any
        # output for the request can arrive.
        self.output_processor.add_request(request, prompt, parent_req, index, queue)

        # Add the EngineCoreRequest to EngineCore (separate process).
        await self.engine_core.add_request_async(request)

        if self.log_requests:
            logger.info("Added request %s.", request.request_id)

    async def _add_streaming_input_request(
        self,
        request_id: str,
        input_stream: AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> RequestOutputCollector:
        """Register a request whose prompt arrives incrementally.

        A background task consumes ``input_stream``. Each chunk becomes a
        resumable EngineCoreRequest that shares one internal request id and is
        passed to ``_add_request``. When the stream ends (normally or with an
        error), a placeholder request built up front is sent to signal that
        inputs are finished. Errors raised by the stream are delivered to the
        returned collector wrapped in ``InputStreamError``.

        Raises:
            ValueError: if ``sampling_params`` are not supported for streaming
                input (per-chunk params are validated inside the task).
        """
        self._validate_streaming_input_sampling_params(sampling_params)

        inputs = dict(
            arrival_time=arrival_time,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            trace_headers=trace_headers,
            priority=priority,
            data_parallel_rank=data_parallel_rank,
        )

        if not sampling_params.skip_clone:
            sampling_params = sampling_params.clone()
            sampling_params.skip_clone = True

        # Create request for validation, also used as the finished signal
        # once the input stream is closed.
        final_req = self.input_processor.process_inputs(
            request_id=request_id,
            prompt=TokensPrompt(prompt_token_ids=[0]),
            params=sampling_params,
            **inputs,  # type: ignore[arg-type]
        )
        self.input_processor.assign_request_id(final_req)
        internal_req_id = final_req.request_id

        queue = RequestOutputCollector(sampling_params.output_kind, internal_req_id)

        async def handle_inputs():
            cancelled = False
            try:
                async for input_chunk in input_stream:
                    sp = input_chunk.sampling_params
                    if sp:
                        self._validate_streaming_input_sampling_params(sp)
                    else:
                        sp = sampling_params
                    req = self.input_processor.process_inputs(
                        request_id=internal_req_id,
                        prompt=input_chunk.prompt,
                        params=sp,
                        resumable=True,
                        **inputs,  # type: ignore[arg-type]
                    )
                    req.external_req_id = request_id
                    if req.prompt_embeds is not None:
                        raise ValueError(
                            "prompt_embeds not supported for streaming inputs"
                        )
                    prompt_text = get_prompt_text(input_chunk.prompt)
                    await self._add_request(req, prompt_text, None, 0, queue)
            except asyncio.CancelledError:
                # Session was aborted. Re-raise so the task is reported as
                # cancelled instead of appearing to complete normally.
                cancelled = True
                raise
            except GeneratorExit:
                cancelled = True
            except Exception as error:
                # Wrap in InputStreamError so generate() can propagate it
                # without wrapping in EngineGenerateError.
                queue.put(InputStreamError(error))
            finally:
                queue._input_stream_task = None
                if not cancelled:
                    # Send empty final request to indicate that inputs have
                    # finished. Don't send if cancelled (session was aborted).
                    try:
                        await self._add_request(final_req, None, None, 0, queue)
                    except Exception as error:
                        # Nobody awaits this task, so an exception here would
                        # be lost and generate() would wait forever for a
                        # finished output. Hand it to the collector instead.
                        queue.put(error)

        # Ensure output handler is running.
        self._run_output_handler()

        queue._input_stream_task = asyncio.create_task(handle_inputs())
        return queue

    @staticmethod
    def _validate_streaming_input_sampling_params(
        params: SamplingParams | PoolingParams,
    ):
        """Reject parameters that input streaming does not support.

        Raises:
            ValueError: if ``params`` is not SamplingParams (pooling), asks
                for ``n > 1``, uses ``RequestOutputKind.FINAL_ONLY``, or sets
                stop strings.
        """
        if (
            not isinstance(params, SamplingParams)
            or params.n > 1
            or params.output_kind == RequestOutputKind.FINAL_ONLY
            or params.stop
        ):
            raise ValueError(
                "Input streaming not currently supported "
                "for pooling models, n > 1, request_kind = FINAL_ONLY "
                "or with stop strings."
            )

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Making an output collector corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the OutputProcessor.
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request collector.

        The caller of generate() iterates the returned AsyncGenerator,
        returning the RequestOutput back to the caller.

        Raises:
            EngineDeadError: the engine died; the request is not aborted.
            ValueError: the request was rejected as invalid.
            Exception: whatever the caller's streaming input generator raised
                (re-raised unwrapped).
            EngineGenerateError: any other unexpected error.
        """

        q: RequestOutputCollector | None = None
        try:
            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                prompt_text=prompt_text,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                assert isinstance(out, RequestOutput)
                finished = out.finished
                # STREAM_FINISHED only terminates the loop; never yield it.
                if out is not STREAM_FINISHED:
                    yield out

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError as e:
            # If the request was already registered, abort it so it does not
            # linger in the OutputProcessor/EngineCore (as the other error
            # branches do).
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed (bad request): %s.", request_id, e)
            raise

        # Error from input stream generator - propagate directly.
        except InputStreamError as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed (input error): %s.", request_id, e)
            raise e.cause from e

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                try:
                    s = f"{e.__class__.__name__}: {e}"
                except Exception as e2:
                    # str(e) itself failed; report both class names instead.
                    s = (
                        f"{e.__class__.__name__}: "
                        "error during printing an exception of class "
                        + e2.__class__.__name__
                    )
                logger.info("Request %s failed due to %s.", request_id, s)
            raise EngineGenerateError() from e
        finally:
            if q is not None:
                q.close()

    def _run_output_handler(self):
        """Background loop: pulls from EngineCore and pushes to AsyncStreams.

        Idempotent: returns immediately if the handler task already exists.
        The task pulls EngineCoreOutputs, processes them in chunks (pushing
        RequestOutputs into per-request collectors), aborts requests the
        OutputProcessor flags, and records stats. Any exception ends the loop
        and is propagated to all requests via
        ``OutputProcessor.propagate_error``.
        """

        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        logger_manager = self.logger_manager
        input_processor = self.input_processor
        # Clamp to >= 1: a step of 0 makes range() raise and a negative step
        # yields an empty range, which would silently drop every output.
        chunk_size = max(1, envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE)

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    num_outputs = len(outputs.outputs)

                    iteration_stats = (
                        IterationStats() if (log_stats and num_outputs) else None
                    )

                    # Split outputs into chunks of at most
                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                    # event loop for too long.
                    engine_core_outputs = outputs.outputs
                    for start in range(0, num_outputs, chunk_size):
                        end = start + chunk_size
                        outputs_slice = engine_core_outputs[start:end]
                        # 2) Process EngineCoreOutputs.
                        processed_outputs = output_processor.process_outputs(
                            outputs_slice, outputs.timestamp, iteration_stats
                        )
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed_outputs.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if end < num_outputs:
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        if processed_outputs.reqs_to_abort:
                            await engine_core.abort_requests_async(
                                processed_outputs.reqs_to_abort
                            )

                    output_processor.update_scheduler_stats(outputs.scheduler_stats)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if logger_manager:
                        logger_manager.record(
                            engine_idx=outputs.engine_index,
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                            mm_cache_stats=input_processor.stat_mm_cache(),
                        )
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())

    async def abort(
        self, request_id: str | Iterable[str], internal: bool = False
    ) -> None:
        """Abort RequestId in OutputProcessor and EngineCore.

        Args:
            request_id: a single request id or an iterable of ids.
            internal: forwarded to ``OutputProcessor.abort_requests``. Callers
                in this module pass ``True`` together with a collector's
                (internal) request id.

        The OutputProcessor returns the full set of ids to abort, which is
        what gets sent to EngineCore.
        """
        request_ids = (
            (request_id,) if isinstance(request_id, str) else as_list(request_id)
        )
        all_request_ids = self.output_processor.abort_requests(request_ids, internal)
        await self.engine_core.abort_requests_async(all_request_ids)

        if self.log_requests:
            logger.info("Aborted request(s) %s.", ",".join(request_ids))

    async def pause_generation(
        self,
        *,
        wait_for_inflight_requests: bool = False,
        clear_cache: bool = True,
    ) -> None:
        """
        Pause generation to allow model weight updates.

        New generation/encoding requests are blocked until resume.
        Calling this while already paused is a no-op.

        Args:
            wait_for_inflight_requests: When ``True`` waits for in-flight
                requests to finish before pausing. When ``False`` (default),
                immediately aborts any in-flight requests.
            clear_cache: Whether to clear KV cache and prefix cache after
                draining. Set to ``False`` to preserve cache for faster resume.
                Default is ``True`` (clear caches).
        """

        async with self._pause_cond:
            if self._paused:
                return
            self._paused = True

        if not wait_for_inflight_requests:
            # Snapshot the ids first; aborting mutates request_states.
            request_ids = list(self.output_processor.request_states.keys())
            if request_ids:
                await self.abort(request_ids, internal=True)

        # Wait for running requests to drain before clearing cache.
        if self.output_processor.has_unfinished_requests():
            await self.output_processor.wait_for_requests_to_drain()

        # Clear cache
        if clear_cache:
            await self.reset_prefix_cache()
            await self.reset_mm_cache()

    async def resume_generation(self) -> None:
        """Lift a pause set by :meth:`pause_generation`.

        Clears the paused flag and notifies every coroutine waiting on the
        pause condition.
        """
        cond = self._pause_cond
        async with cond:
            self._paused = False
            cond.notify_all()

    async def is_paused(self) -> bool:
        """Report whether generation is currently paused.

        The flag is read while holding the pause condition's lock.
        """
        async with self._pause_cond:
            paused = self._paused
        return paused

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Main function called by the API server to kick off a pooling request
            * 1) Making an AsyncStream corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request AsyncStream.

        The caller of encode() iterates the returned AsyncGenerator,
        receiving PoolingRequestOutput objects until one is ``finished``.

        Raises:
            asyncio.CancelledError: Re-raised after aborting the request.
            GeneratorExit: Re-raised after aborting the request when the
                caller closes the generator before it finishes.
            EngineDeadError: Re-raised unchanged; no abort is attempted.
            ValueError: Re-raised unchanged (request validation error).
            EngineGenerateError: Wraps any other exception, raised after
                aborting the request.
        """

        q: RequestOutputCollector | None = None
        try:
            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the client disconnects, encode() is cancelled. If the caller
        # stops iterating early, the generator is closed (aclose() or garbage
        # collection), which raises GeneratorExit at the ``yield`` above.
        # GeneratorExit is a BaseException, so the generic ``except Exception``
        # below would not see it and the request would keep running in
        # EngineCore. Abort in both cases.
        except (asyncio.CancelledError, GeneratorExit):
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            if q is not None:
                await self.abort(q.request_id, internal=True)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

        finally:
            if q is not None:
                q.close()

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """The input processor's tokenizer (typed as possibly ``None``)."""
        processor = self.input_processor
        return processor.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer from ``InputProcessor.get_tokenizer()``."""
        processor = self.input_processor
        return processor.get_tokenizer()

    @property
    def renderer(self) -> RendererLike:
        """The renderer owned by the input processor."""
        processor = self.input_processor
        return processor.renderer

    async def is_tracing_enabled(self) -> bool:
        """Return ``True`` when an OTLP traces endpoint is configured."""
        endpoint = self.observability_config.otlp_traces_endpoint  # type: ignore
        return endpoint is not None

    async def do_log_stats(self) -> None:
        """Emit stats through the logger manager, if one is set (truthy)."""
        manager = self.logger_manager
        if not manager:
            return
        manager.log()

    async def check_health(self) -> None:
        """Raise ``self.dead_error`` when the engine reports it has errored."""
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error

    async def start_profile(self) -> None:
        """Start profiling in EngineCore and, when configured, locally.

        The local profiler's ``start`` runs via ``asyncio.to_thread`` and is
        awaited together with the EngineCore call using ``asyncio.gather``.
        """
        pending = [self.engine_core.profile_async(True)]
        profiler = self.profiler
        if profiler is not None:
            pending.append(asyncio.to_thread(profiler.start))
        await asyncio.gather(*pending)

    async def stop_profile(self) -> None:
        """Stop profiling in EngineCore and, when configured, locally.

        The local profiler's ``stop`` runs via ``asyncio.to_thread`` and is
        awaited together with the EngineCore call using ``asyncio.gather``.
        """
        pending = [self.engine_core.profile_async(False)]
        profiler = self.profiler
        if profiler is not None:
            pending.append(asyncio.to_thread(profiler.stop))
        await asyncio.gather(*pending)

    async def reset_mm_cache(self) -> None:
        """Clear the multimodal cache in the input processor, then in EngineCore."""
        self.input_processor.clear_mm_cache()
        core = self.engine_core
        await core.reset_mm_cache_async()

    async def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the prefix cache in EngineCore.

        Args:
            reset_running_requests: Forwarded unchanged to EngineCore.
            reset_connector: Forwarded unchanged to EngineCore.

        Returns:
            The value returned by ``engine_core.reset_prefix_cache_async``.
        """
        core = self.engine_core
        result = await core.reset_prefix_cache_async(
            reset_running_requests, reset_connector
        )
        return result

    async def reset_encoder_cache(self) -> None:
        """Ask EngineCore to reset its encoder cache."""
        core = self.engine_core
        await core.reset_encoder_cache_async()

    async def sleep(self, level: int = 1) -> None:
        """Put the engine to sleep at ``level``.

        The prefix cache is reset first. Afterwards, if a logger manager
        exists, ``record_sleep_state(1, level)`` is called on it.
        """
        await self.reset_prefix_cache()
        await self.engine_core.sleep_async(level)

        manager = self.logger_manager
        if manager is None:
            return
        manager.record_sleep_state(1, level)

    async def wake_up(self, tags: list[str] | None = None) -> None:
896
        await self.engine_core.wake_up_async(tags)
897

898
899
900
        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

901
902
903
    async def is_sleeping(self) -> bool:
        """Ask EngineCore whether the engine is currently sleeping."""
        sleeping = await self.engine_core.is_sleeping_async()
        return sleeping

    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests.

        Returns whatever ``engine_core.add_lora_async`` reports.
        """
        core = self.engine_core
        added = await core.add_lora_async(lora_request)
        return added

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter by id.

        Returns whatever ``engine_core.remove_lora_async`` reports.
        """
        core = self.engine_core
        removed = await core.remove_lora_async(lora_id)
        return removed

    async def list_loras(self) -> set[int]:
        """List all registered adapters, as reported by EngineCore."""
        core = self.engine_core
        registered = await core.list_loras_async()
        return registered

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted.

        Returns whatever ``engine_core.pin_lora_async`` reports.
        """
        core = self.engine_core
        pinned = await core.pin_lora_async(lora_id)
        return pinned

    async def collective_rpc(
        self,
        method: str,
923
        timeout: float | None = None,
924
        args: tuple = (),
925
        kwargs: dict | None = None,
926
    ):
927
928
929
930
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
931
932
            method, timeout, args, kwargs
        )
933

934
935
936
937
938
939
940
941
    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """Wait until no data-parallel engine reports it is running.

        Polls ``engine_core.dp_engines_running()`` once per second.

        Args:
            drain_timeout: Maximum time to wait, in seconds.

        Raises:
            TimeoutError: If the engines are still running after
                ``drain_timeout`` seconds.
        """
        # Use a monotonic clock for the deadline: time.time() follows the
        # wall clock, which can jump (NTP sync, manual changes) and would
        # silently shorten or extend the timeout.
        deadline = time.monotonic() + drain_timeout
        while time.monotonic() < deadline:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

            logger.info("Engines are still running, waiting for requests to drain...")
            await asyncio.sleep(1)  # Wait 1 second before checking again

        raise TimeoutError(
            f"Timeout reached after {drain_timeout} seconds "
            "waiting for requests to drain."
        )

    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ):
        """
        Scale up or down the data parallel size by adding or removing
        engine cores.

        Waits for in-flight requests to drain, asks EngineCore to rescale,
        then updates ``parallel_config.data_parallel_size``. When scaling up
        with ``log_stats`` enabled, the stat logger manager is rebuilt.

        Args:
            new_data_parallel_size: The new number of data parallel workers
            drain_timeout:
                Maximum time to wait for requests to drain (seconds)

        Raises:
            TimeoutError: If requests do not drain within ``drain_timeout``
                (raised by ``wait_for_requests_to_drain``).
        """
        old_data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if old_data_parallel_size == new_data_parallel_size:
            logger.info(
                "Data parallel size is already %s, skipping scale",
                new_data_parallel_size,
            )
            return

        # This path handles both growing and shrinking, so log the actual
        # direction instead of always saying "up".
        direction = (
            "up" if new_data_parallel_size > old_data_parallel_size else "down"
        )
        logger.info(
            "Waiting for requests to drain before scaling %s to %s engines...",
            direction,
            new_data_parallel_size,
        )
        await self.wait_for_requests_to_drain(drain_timeout)
        logger.info(
            "Requests have been drained, proceeding with scale to %s engines",
            new_data_parallel_size,
        )
        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size

        # recreate stat loggers
        if new_data_parallel_size > old_data_parallel_size and self.log_stats:
            # TODO(rob): fix this after talking with Ray team.
            # This resets all the prometheus metrics since we
            # unregister during initialization. Need to understand
            # the intended behavior here better.
            # NOTE(review): custom_stat_loggers=None may drop any custom
            # stat loggers configured elsewhere — confirm this is intended.
            self.logger_manager = StatLoggerManager(
                vllm_config=self.vllm_config,
                engine_idxs=list(range(new_data_parallel_size)),
                custom_stat_loggers=None,
            )

    @property
    def is_running(self) -> bool:
        """Whether the background output handler task is still alive.

        ``output_handler`` is ``None`` before the loop is started, which
        counts as running.
        """
        handler = self.output_handler
        if handler is None:
            return True
        return not handler.done()

    @property
    def is_stopped(self) -> bool:
        """Alias for ``errored``."""
        stopped = self.errored
        return stopped

    @property
    def errored(self) -> bool:
        """True when EngineCore reports the engine dead, or when the output
        handler is no longer running."""
        dead = self.engine_core.resources.engine_dead
        if dead:
            return dead
        return not self.is_running

    @property
    def dead_error(self) -> BaseException:
        """A new ``EngineDeadError`` instance on each access (returned, not raised)."""
        error = EngineDeadError()
        return error