async_llm_engine.py 50.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import copy
6
import time
7
import weakref
Antoni Baum's avatar
Antoni Baum committed
8
from functools import partial
9
10
from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
                    List, Mapping, Optional, Set, Tuple, Type, Union, overload)
11
from weakref import ReferenceType
12

13
14
from typing_extensions import deprecated

15
import vllm.envs as envs
16
17
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, VllmConfig)
18
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
19
from vllm.engine.arg_utils import AsyncEngineArgs
20
from vllm.engine.async_timeout import asyncio_timeout
21
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
22
from vllm.engine.metrics_types import StatLoggerBase
23
from vllm.engine.protocol import EngineClient
24
from vllm.executor.executor_base import ExecutorBase
25
from vllm.inputs import PromptType
26
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
27
from vllm.logger import init_logger
28
from vllm.lora.request import LoRARequest
29
30
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
31
from vllm.model_executor.layers.sampler import SamplerOutput
32
from vllm.outputs import PoolingRequestOutput, RequestOutput
33
from vllm.pooling_params import PoolingParams
34
from vllm.prompt_adapter.request import PromptAdapterRequest
35
from vllm.sampling_params import SamplingParams
36
from vllm.sequence import ExecuteModelRequest
37
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
38
from vllm.usage.usage_lib import UsageContext
39
from vllm.utils import Device, deprecate_kwargs, weak_bind
40
41

logger = init_logger(__name__)
42
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
43

Antoni Baum's avatar
Antoni Baum committed
44

45
46
47
48
class AsyncEngineDeadError(RuntimeError):
    """Signals that the background engine task has terminated unexpectedly."""


49
50
51
52
53
54
55
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback for the `engine.run_engine_loop()` background task.

    That task is a `while True` loop, so it is only expected to end through
    cancellation or an exception; a normal return is treated as a failure.

    Args:
        task: The finished background task to inspect.
        error_callback: Invoked with the failure before
            `AsyncEngineDeadError` is raised.

    Raises:
        AsyncEngineDeadError: If the task ended for any reason other than
            cancellation (including returning a value).
    """
    try:
        finished_value = task.result()
        # A plain return is a bug; raising here routes it through the same
        # failure path as any other exception from the task.
        raise AssertionError(
            "The engine background task should never finish without an "
            f"exception. {finished_value}")
    except asyncio.exceptions.CancelledError:
        # Cancellation is taken to mean a graceful shutdown, which should
        # only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as exc:
        logger.error("Engine background task failed", exc_info=exc)
        error_callback(exc)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on GitHub. See stack trace above for the "
            "actual cause.") from exc
75
76


77
78
79
# Sentinel queued by `AsyncStream.finish()` when no raisable exception is
# given; `AsyncStream.generator()` ends the stream when it dequeues it. It is
# an `Exception` instance so that it passes `AsyncStream._is_raisable`.
STOP_ITERATION = Exception()  # Sentinel


Antoni Baum's avatar
Antoni Baum committed
80
class AsyncStream:
    """A stream of RequestOutputs or PoolingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator.

    Producers push items with `put()` and close the stream with `finish()`;
    a consumer iterates over `generator()`.
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        """Create an empty, unfinished stream.

        Args:
            request_id: Identifier of the request this stream belongs to.
            cancel: Callback invoked with `request_id` when the consumer
                closes the generator before the stream has ended.
        """
        self.request_id = request_id
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, PoolingRequestOutput,
                              Exception]) -> None:
        """Enqueue `item` for the consumer; ignored once finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Close the stream; only the first call has any effect.

        Args:
            exception: Exception instance or class to raise in the consumer.
                Anything that is not raisable (including `None`) ends the
                stream normally instead.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        """Whether `finish()` has been called."""
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        """Yield queued outputs until the stream ends.

        Returns normally when the end-of-stream sentinel is dequeued and
        raises any other raisable item taken from the queue.

        Raises:
            asyncio.CancelledError: If the generator is closed by the
                consumer; the cancel callback is invoked first.
        """
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    # Identity, not equality: STOP_ITERATION is a sentinel, so
                    # an exception type with a custom `__eq__` must never be
                    # mistaken for it.
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            # The consumer stopped iterating early: cancel the request and
            # surface this as a cancellation without the GeneratorExit context.
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """Return True if `value` is an exception instance or class."""
        return isinstance(value, BaseException) or \
                (isinstance(value, type) and \
                 issubclass(value, BaseException))

Antoni Baum's avatar
Antoni Baum committed
129

130
131
132
133
134
class RequestTracker:
    """Synchronous bookkeeping for requests handed to the engine.

    Keeps the live `AsyncStream` per request id plus two queues (new and
    aborted requests) that are drained by `get_new_and_aborted_requests()`.
    """

    def __init__(self) -> None:
        self._request_streams: Dict[str, AsyncStream] = {}
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set whenever a request is queued by `add_request()`.
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Abort one request with `exc`, or every tracked request when
        `request_id` is None."""
        if request_id is None:
            # abort_request() pops from self._request_streams, so iterate
            # over a snapshot of the ids rather than the dict itself.
            for tracked_id in list(self._request_streams):
                self.abort_request(tracked_id, exception=exc)
            return
        self.abort_request(request_id, exception=exc)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     PoolingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Route an engine output to its stream, closing the stream when the
        output is marked finished."""
        rid = request_output.request_id
        is_done = request_output.finished

        # A finished request is removed from tracking; otherwise just look
        # it up. Either way the stream may be missing if the request was
        # aborted while this output was being generated.
        lookup = (self._request_streams.pop
                  if is_done else self._request_streams.get)
        stream = lookup(rid, None)

        if stream is not None:
            stream.put(request_output)
            if is_done:
                stream.finish()

        if is_done and verbose:
            logger.info("Finished request %s.", rid)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Forward an engine-side exception to the request's stream by
        aborting the request with it."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Queue a request for the next background loop iteration and return
        the stream its outputs will arrive on.

        Raises:
            KeyError: If `request_id` is already being tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        # The stream's cancel callback aborts the request with the same
        # verbosity that was used to add it.
        on_cancel = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, on_cancel)

        engine_kwargs = {"request_id": request_id, **engine_add_request_kwargs}
        self._new_requests.put_nowait((stream, engine_kwargs))
        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Queue `request_id` as aborted and finish its stream (if tracked),
        optionally with `exception`."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)

        tracked = self._request_streams.pop(request_id, None)
        if tracked is not None:
            tracked.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Drain both queues and return `(new request kwargs, aborted ids)`
        to be sent to the engine."""
        aborted_ids: Set[str] = set()
        while not self._aborted_requests.empty():
            aborted_ids.add(self._aborted_requests.get_nowait())

        admitted: List[Dict] = []
        while not self._new_requests.empty():
            stream, engine_kwargs = self._new_requests.get_nowait()
            rid = stream.request_id
            if rid not in aborted_ids:
                self._request_streams[rid] = stream
                admitted.append(engine_kwargs)
                continue
            # Aborted before it was ever handed to the engine: cancel the
            # stream and drop the id from the aborted set.
            stream.finish(asyncio.CancelledError)
            aborted_ids.discard(rid)

        return admitted, aborted_ids

    async def wait_for_new_requests(self):
        """Block until a request is queued (if none is), then clear the
        event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Return True if the new-request queue is non-empty."""
        return not self._new_requests.empty()
262

Antoni Baum's avatar
Antoni Baum committed
263
264
265
266

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

267
268
269
    def __init__(self, *args, **kwargs):
        """Pass every positional and keyword argument through to the parent
        class initializer."""
        super().__init__(*args, **kwargs)

270
    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.

        Args:
            virtual_engine: Index used to select the per-virtual-engine
                scheduler, scheduler context, cached scheduler outputs and
                async callback.

        Returns:
            The `request_outputs` list held by this virtual engine's
            scheduler context (cleared at the start of every call).
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):

            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            # NOTE: finished_requests_ids is bound here only when the
            # scheduler produced work. Its only later read (building the
            # ExecuteModelRequest) sits under the same is_empty() check.
            if not scheduler_outputs.is_empty():
                # this will cause mamba_cache/minimax_cache failed
                # to release finished_requests_ids of the last steps
                finished_requests_ids = self.scheduler[
                    virtual_engine].get_and_reset_finished_requests_ids()

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)
        else:
            # The scheduler was skipped (steps remain), so an empty list of
            # finished request ids is passed to the model request.
            finished_requests_ids = list()

        # Either the cache or the scheduler call above must have supplied
        # both of these.
        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            # Nothing was scheduled: process any outputs still queued and
            # report no model outputs for this step.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # Clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            if outputs and allow_async_output_proc:
                assert len(
                    outputs
                ) == 1, "Async postprocessor expects only a single output set"
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            # Synchronous path: process the outputs appended above right
            # away, then log stats and tracing for this step.
            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)

        else:
            # Multi-step case
            # Steps remain for this batch: return what ctx holds so far and
            # skip the drain check below.
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs
421

422
423
424
425
    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Ask the model executor to stop its remote worker execution loop
        and wait for that call to complete."""
        executor = self.model_executor
        await executor.stop_remote_worker_execution_loop_async()

426
427
428
429
430
431
    async def get_tokenizer_async(self,
                                  lora_request: Optional[LoRARequest] = None
                                  ) -> AnyTokenizer:
        """Fetch the tokenizer for `lora_request` from the tokenizer group.

        Args:
            lora_request: Optional LoRA request, passed through unchanged to
                the tokenizer group's `get_lora_tokenizer_async`.

        Returns:
            Whatever `get_lora_tokenizer_async` resolves to.
        """
        tokenizer_group = self.get_tokenizer_group()
        tokenizer = await tokenizer_group.get_lora_tokenizer_async(
            lora_request)
        return tokenizer

432
433
    @overload
    @deprecated("'inputs' will be renamed to 'prompt'")
    async def add_request_async(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> None:
        ...

    @overload
    async def add_request_async(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> None:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request_async(
            self,
            request_id: str,
            prompt: Optional[PromptType] = None,
            params: Optional[Union[SamplingParams, PoolingParams]] = None,
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
            data_parallel_rank: Optional[int] = None,
            *,
            inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        """Async version of
        [`add_request`][vllm.engine.llm_engine.LLMEngine.add_request].

        Preprocesses the prompt asynchronously, builds guided-decoding logits
        processors asynchronously when requested, and then registers the
        request through `_add_processed_request`.

        Args:
            request_id: Identifier of the request.
            prompt: The prompt to process. Required unless the deprecated
                `inputs` keyword is given.
            params: Sampling or pooling parameters. Required.
            arrival_time: Arrival timestamp; defaults to `time.time()`.
            lora_request: Optional LoRA request; only valid when LoRA is
                enabled on this engine.
            trace_headers: Forwarded to `_add_processed_request`.
            prompt_adapter_request: Forwarded to preprocessing and to
                `_add_processed_request`.
            priority: Request priority; a non-zero value requires the
                "priority" scheduling policy.
            data_parallel_rank: Accepted for signature compatibility; not
                referenced by this implementation.
            inputs: Deprecated alias of `prompt`; takes precedence over
                `prompt` when not None.

        Raises:
            AssertionError: If neither a prompt nor `params` is provided.
            ValueError: If `lora_request` is given while LoRA is disabled, or
                a non-zero `priority` is given without priority scheduling.
        """
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")
        if arrival_time is None:
            arrival_time = time.time()

        # Embedding-only prompts get placeholder token ids (zeros), one per
        # embedding row. NOTE: this writes into the caller's prompt dict.
        if (isinstance(prompt, dict)
                and prompt.get("prompt_embeds", None) is not None
                and not prompt.get("prompt_token_ids", None)):
            # We use the -2 dimension (instead of 0) in case a batched input
            # of batch size 1 is passed in.
            prompt["prompt_token_ids"] = [0
                                          ] * prompt["prompt_embeds"].shape[-2]

        processed_inputs = await self.input_preprocessor.preprocess_async(
            prompt,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        if isinstance(params, SamplingParams) and \
            params.guided_decoding is not None:
            # Guided decoding has an async implementation for building logits
            # processors in a separate threadpool.
            # We want to invoke that here instead of using the blocking
            # implementation in the LLMEngine
            params = await build_guided_decoding_logits_processor_async(
                sampling_params=params,
                tokenizer=await self.get_tokenizer_async(lora_request),
                default_guided_backend=self.decoding_config.
                guided_decoding_backend,
                reasoning_backend=self.decoding_config.reasoning_backend,
                model_config=self.model_config)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )
535

536
537
    async def check_health_async(self) -> None:
        """Run the model executor's `check_health()`; the call itself is made
        synchronously (it is not awaited)."""
        executor = self.model_executor
        executor.check_health()
538

539
540
541
542
543
544
545
    async def collective_rpc_async(self,
                                   method: str,
                                   timeout: Optional[float] = None,
                                   args: tuple = (),
                                   kwargs: Optional[dict] = None):
        """Placeholder: not implemented by this class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

546

547
548
async def build_guided_decoding_logits_processor_async(
        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
        default_guided_backend: str, reasoning_backend: Optional[str],
        model_config: ModelConfig) -> SamplingParams:
    """Build the guided-decoding logits processor for `sampling_params`.

    When `sampling_params.guided_decoding` is set, a copy of the params is
    returned with the constructed processor appended to `logits_processors`
    and `guided_decoding` reset to None. When it is not set, the input object
    is returned unchanged.

    Args:
        sampling_params: Params that may carry a `guided_decoding` spec.
        tokenizer: Passed to `get_guided_decoding_logits_processor`.
        default_guided_backend: Backend used when the guided-decoding spec
            does not name one.
        reasoning_backend: Passed to `get_guided_decoding_logits_processor`;
            also included in the debug log when not None.
        model_config: Passed to `get_guided_decoding_logits_processor`.

    Returns:
        The original `sampling_params` if no guided decoding was requested,
        otherwise the modified copy.
    """
    if sampling_params.guided_decoding is None:
        return sampling_params

    # Defensively copy sampling params since guided decoding logits
    # processors can have different state for each request
    sampling_params = copy.copy(sampling_params)
    guided_decoding = sampling_params.guided_decoding

    logger.debug(
        "Building guided decoding logits processor. "
        "guided_decoding: %s%s", guided_decoding,
        f", reasoning_backend: {reasoning_backend}"
        if reasoning_backend is not None else "")

    # NOTE(review): copy.copy is shallow (unless SamplingParams customizes
    # copying), so `guided_decoding` is presumably still the caller's object
    # and this backend default is visible to the caller -- confirm whether
    # that is intended.
    guided_decoding.backend = guided_decoding.backend or default_guided_backend

    processor = await get_guided_decoding_logits_processor(
        guided_params=guided_decoding,
        tokenizer=tokenizer,
        reasoning_backend=reasoning_backend,
        model_config=model_config)

    if processor:
        # Rebind to a fresh list instead of appending in place: after the
        # shallow copy above, an existing `logits_processors` list is still
        # shared with the caller's params, and appending to it would leak
        # this request's processor into every later use of those params.
        sampling_params.logits_processors = [
            *(sampling_params.logits_processors or ()), processor
        ]

    # Unset guided decoding params after constructing the lp from them
    sampling_params.guided_decoding = None

    return sampling_params


589
class AsyncLLMEngine(EngineClient):
590
    """An asynchronous wrapper for [`LLMEngine`][vllm.LLMEngine].
591

592
593
594
595
596
597
    This class is used to wrap the [`LLMEngine`][vllm.LLMEngine] class to
    make it asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The [`LLMEngine`][vllm.LLMEngine] is kicked
    by the generate method when there are requests in the waiting queue. The
    generate method yields the outputs from the [`LLMEngine`][vllm.LLMEngine]
    to the caller.
598
599

    Args:
600
        log_requests: Whether to log the requests.
601
602
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
603
604
        *args: Arguments for [`LLMEngine`][vllm.LLMEngine].
        **kwargs: Arguments for [`LLMEngine`][vllm.LLMEngine].
605
    """
606

Antoni Baum's avatar
Antoni Baum committed
607
608
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

609
610
611
    def __init__(self,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Create the wrapped engine and initialise background-loop state.

        Args:
            *args: Positional arguments forwarded to `_engine_class`.
            log_requests: Whether to log the requests.
            start_engine_loop: If True, `add_request` starts the background
                loop on demand.
            **kwargs: Keyword arguments forwarded to `_engine_class`.

        Raises:
            ValueError: If `envs.VLLM_USE_V1` is truthy.
        """
        if envs.VLLM_USE_V1:
            raise ValueError(
                "Using V0 AsyncLLMEngine, but envs.VLLM_USE_V1=True. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        self.log_requests = log_requests
        self.engine = self._engine_class(*args, **kwargs)

        # This ensures quick processing of request outputs
        # so the append to asyncio queues is not delayed,
        # especially for multi-step.
        use_callback = self.engine.model_config.use_async_output_proc
        self.use_process_request_outputs_callback = use_callback
        if use_callback:
            self.engine.process_request_outputs_callback = weak_bind(
                self.process_request_outputs)

        # Shielded handle on the background loop task.
        self.background_loop: Optional[asyncio.Future] = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Lazy initialized fields (assigned in start_background_loop)
        self._request_tracker: RequestTracker

645
646
647
648
649
    def __del__(self):
        """Wake the background loop when the engine is garbage collected.

        The tracker is stored as `_request_tracker` and is only assigned
        once the background loop has been started, hence the `getattr`
        with a default.
        """
        if rt := getattr(self, "_request_tracker", None):
            # Wake up engine loop so that it will exit cleanly
            rt.new_requests_event.set()

650
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Return the executor class `LLMEngine` selects for this config."""
        executor_cls = LLMEngine._get_executor_cls(engine_config)
        return executor_cls
654

655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
        disable_log_requests: bool = False,
        disable_log_stats: bool = False,
    ) -> "AsyncLLMEngine":
        """Create an AsyncLLMEngine from a `VllmConfig`.

        The `disable_log_*` flags are inverted into the `log_requests` /
        `log_stats` constructor arguments; the executor class is resolved
        through `_get_executor_cls`.
        """
        executor_class = cls._get_executor_cls(vllm_config)
        log_requests = not disable_log_requests
        log_stats = not disable_log_stats

        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            start_engine_loop=start_engine_loop,
            log_requests=log_requests,
            log_stats=log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

677
678
679
680
681
682
683
684
685
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Build an async engine from `AsyncEngineArgs`.

        The engine config is created from `engine_args` first. When
        `envs.VLLM_USE_V1` is truthy the V1 `AsyncLLM` class is used to
        construct the engine instead of `cls`.
        """
        vllm_config = engine_args.create_engine_config(usage_context)

        if envs.VLLM_USE_V1:
            # Imported lazily, only when the V1 engine is selected.
            from vllm.v1.engine.async_llm import AsyncLLM as V1AsyncLLMEngine
            engine_cls = V1AsyncLLMEngine
        else:
            engine_cls = cls

        return engine_cls.from_vllm_config(
            vllm_config=vllm_config,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            disable_log_stats=engine_args.disable_log_stats,
            disable_log_requests=engine_args.disable_log_requests,
        )
702

703
704
    @property
    def is_running(self) -> bool:
        """True while a background loop task exists and has not finished."""
        if self.background_loop is None:
            return False
        task = self._background_loop_unshielded
        return task is not None and not task.done()

    @property
    def is_stopped(self) -> bool:
        """True if the engine errored or its background task has finished."""
        if self.errored:
            return True
        if self.background_loop is None:
            return False
        task = self._background_loop_unshielded
        return task is not None and task.done()

    @property
    def errored(self) -> bool:
        """True once an exception has been recorded via `set_errored`."""
        recorded = self._errored_with
        return recorded is not None

719
    @property
    def dead_error(self) -> BaseException:
        """A new `AsyncEngineDeadError` describing a stopped background loop."""
        message = ("Background loop is not running. If it was running, "
                   "inspect the output to find the stacktrace of the "
                   "error that caused the background loop to stop "
                   "(AsyncEngineDeadError).")
        return AsyncEngineDeadError(message)
726

727
728
729
730
731
732
    def set_errored(self, exc: Exception) -> None:
        """Record `exc` as the error that brought the engine down."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine as errored and forward `exc` to the tracker.

        Registered as the `error_callback` of the background task's
        completion handler in `start_background_loop`.
        """
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)
733

734
735
736
    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Return the wrapped engine's input preprocessor."""
        preprocessor = self.engine.input_preprocessor
        return preprocessor

737
738
739
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer from the wrapped engine.

        Args:
            lora_request: Optional LoRA request forwarded to the engine's
                `get_tokenizer_async`.
        """
        tokenizer = await self.engine.get_tokenizer_async(lora_request)
        return tokenizer
742

743
    def start_background_loop(self) -> None:
        """Create the request tracker and schedule the engine loop task.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the background loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        loop = asyncio.get_event_loop()
        # The loop coroutine only receives a weak reference to this engine.
        engine_task = loop.create_task(
            self.run_engine_loop(weakref.ref(self)))
        self._background_loop_unshielded = engine_task

        on_done = partial(_log_task_completion,
                          error_callback=self._error_callback)
        engine_task.add_done_callback(on_done)

        self.background_loop = asyncio.shield(engine_task)
Antoni Baum's avatar
Antoni Baum committed
758

759
760
761
762
763
764
765
766
767
768
769
770
771
772
    def shutdown_background_loop(self) -> None:
        """
        Cancel the background loop task and drop both handles to it.

        This method needs to be called during cleanup to remove
        references to `self` and properly GC the resources held
        by the async LLM engine (e.g., the executors as well as
        their resources).
        """
        task = self._background_loop_unshielded
        if task is not None:
            task.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

773
    async def engine_step(self, virtual_engine: int) -> bool:
        """Run one engine iteration for `virtual_engine`.

        Newly tracked requests are added to the engine, pending aborts are
        applied, and then one `step_async` is executed.

        Returns:
            True if there are in-progress requests.
        """
        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())

        for request_kwargs in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            try:
                await self.engine.add_request_async(**request_kwargs)
            except ValueError as err:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    request_kwargs["request_id"],
                    err,
                    verbose=self.log_requests,
                )

        if aborted_requests:
            await self._engine_abort(aborted_requests)

        request_outputs = await self.engine.step_async(virtual_engine)

        if self.use_process_request_outputs_callback:
            # The outputs were already put into their streams by the
            # callback invoked inside LLMEngine's _process_model_outputs;
            # here we only need to detect when all requests are finished.
            all_finished = all(output.finished for output in request_outputs)
        else:
            # Put the outputs into the corresponding streams.
            all_finished = self.process_request_outputs(request_outputs)

        return not all_finished

    def process_request_outputs(self, request_outputs) -> bool:
        """Hand every output to the request tracker.

        Returns:
            True when `request_outputs` is empty or every output in it
            reports `finished`.
        """
        all_finished = True
        for output in request_outputs:
            # Put the output into the corresponding stream.
            self._request_tracker.process_request_output(
                output, verbose=self.log_requests)
            if all_finished:
                all_finished = output.finished

        return all_finished
820

Antoni Baum's avatar
Antoni Baum committed
821
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Abort the given request ids on the wrapped engine."""
        self.engine.abort_request(request_ids)
Antoni Baum's avatar
Antoni Baum committed
823

824
825
826
827
    @staticmethod
    async def run_engine_loop(engine_ref: ReferenceType):
        """Background coroutine that keeps stepping the engine.

        We use a weakref to the engine so that the running loop
        doesn't prevent the engine being garbage collected. The coroutine
        returns as soon as the weak reference is found to be dead.

        One `engine_step` task is kept per virtual engine (pipeline
        parallel stage). Whenever no virtual engine has work left, the
        loop parks until the request tracker signals new requests.

        Raises:
            asyncio.TimeoutError: If waiting for a step exceeds
                `ENGINE_ITERATION_TIMEOUT_S`; the engine is marked as
                errored before re-raising.
        """
        engine: Optional[AsyncLLMEngine] = engine_ref()
        if not engine:
            return

        pp_size = engine.engine.parallel_config.pipeline_parallel_size
        # busy[i] tells whether virtual engine i still has work to do.
        busy = [False] * pp_size
        while True:
            if not any(busy):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                await engine.engine.stop_remote_worker_execution_loop_async()
                tracker = engine._request_tracker
                # Allow engine to be garbage collected while
                # waiting for new requests
                del engine
                await asyncio.sleep(0)
                if engine_ref() is None:
                    return
                await tracker.wait_for_new_requests()
                engine = engine_ref()
                if not engine:
                    return
                logger.debug("Got new requests!")
                step_tasks = [
                    asyncio.create_task(engine.engine_step(idx))
                    for idx in range(pp_size)
                ]
                busy = [True] * pp_size

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    finished_tasks, _ = await asyncio.wait(
                        step_tasks, return_when=asyncio.FIRST_COMPLETED)
                    for _ in range(pp_size):
                        await asyncio.sleep(0)
                for finished in finished_tasks:
                    step_result = finished.result()
                    # The task's slot in the list is its virtual engine id.
                    virtual_engine = step_tasks.index(finished)
                    still_unfinished = (
                        engine.engine.
                        has_unfinished_requests_for_virtual_engine(
                            virtual_engine))
                    if step_result or still_unfinished:
                        step_tasks[virtual_engine] = asyncio.create_task(
                            engine.engine_step(virtual_engine))
                        busy[virtual_engine] = True
                    else:
                        busy[virtual_engine] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                engine.set_errored(exc)
                raise
            await asyncio.sleep(0)

893
894
    # This method does not need to be async, but kept that way
    # for backwards compatibility.
895
896
    @overload
    @deprecated("'inputs' will be renamed to 'prompt'")
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, PoolingRequestOutput], None]]:
        ...

    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, PoolingRequestOutput], None]]:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request(
        self,
        request_id: str,
        prompt: Optional[PromptType] = None,
        params: Optional[Union[SamplingParams, PoolingParams]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        """Register a request with the tracker and return its output stream.

        If the background loop is not running it is started first, provided
        `start_engine_loop` is True.

        Args:
            request_id: The unique id of the request.
            prompt: The prompt of the request.
            params: Sampling or pooling parameters of the request.
            arrival_time: Arrival timestamp; `time.time()` is used when
                this is falsy.
            lora_request: LoRA request to use, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use, if any.
            priority: The priority of the request. A non-zero value is only
                accepted when the scheduler policy is "priority".
            data_parallel_rank: Data parallel rank passed through to the
                request tracker.
            inputs: Deprecated alias of `prompt`; takes precedence over
                `prompt` when given.

        Returns:
            The async generator of the stream created for this request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                `start_engine_loop` is False.
            ValueError: If `priority` is non-zero but priority scheduling
                is not enabled.
        """
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                # Same message as before, now taken from the shared property.
                raise self.dead_error

        if (priority != 0
                and self.engine.scheduler_config.policy != "priority"):
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            prompt=prompt,
            params=params,
            arrival_time=arrival_time or time.time(),
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
            data_parallel_rank=data_parallel_rank,
        )

        return stream.generator()
980

981
    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Stream the generation outputs of a single request.

        This method is an async generator. It adds the request into the
        waiting queue of the LLMEngine and streams the outputs from the
        LLMEngine to the caller. If the consumer is cancelled, the request
        is aborted before the cancellation is re-raised.

        Args:
            prompt: The prompt to the LLM. See
                [`PromptType`][vllm.inputs.PromptType] for more details about
                the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use
                for generation, if any.
            priority: The priority of the request.
                Only applicable with priority scheduling.
            data_parallel_rank: The (global) data parallel rank that must
                handle this request. Only applicable if DP is enabled.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              [`engine_step`][vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step]
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> # note that engine_args here is AsyncEngineArgs instance
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": 0,
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        try:
            stream = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
            )
            async for output in stream:
                yield LLMEngine.validate_output(output, RequestOutput)
        except asyncio.CancelledError:
            # Make sure the engine stops working on a request nobody reads.
            await self.abort(request_id)
            raise
1075
1076
1077

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Stream the outputs of a single request to a pooling model.

        This method is an async generator. It adds the request into the
        waiting queue of the LLMEngine and streams the outputs from the
        LLMEngine to the caller. If the consumer is cancelled, the request
        is aborted before the cancellation is re-raised.

        Args:
            prompt: The prompt to the LLM. See
                [`PromptType`][vllm.inputs.PromptType] for more details about
                the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Yields:
            The output `PoolingRequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
                which iteratively invokes
                [`vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`][]
                to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
                On the next background loop, this request will be sent to
                the underlying engine.
                Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
        ```
        # Please refer to entrypoints/api_server.py for
        # the complete example.

        # initialize the engine and the example input
        # note that engine_args here is AsyncEngineArgs instance
        engine = AsyncLLMEngine.from_engine_args(engine_args)
        example_input = {
            "input": "What is LLM?",
            "request_id": 0,
        }

        # start the generation
        results_generator = engine.encode(
            example_input["input"],
            PoolingParams(),
            example_input["request_id"])

        # get the results
        final_output = None
        async for request_output in results_generator:
            if await request.is_disconnected():
                # Abort the request if the client disconnects.
                await engine.abort(example_input["request_id"])
                # Return or raise an error
                ...
            final_output = request_output

        # Process and return the final output
        ...
        ```
        """
        try:
            stream = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )
            async for output in stream:
                yield LLMEngine.validate_output(output, PoolingRequestOutput)
        except asyncio.CancelledError:
            # Make sure the engine stops working on a request nobody reads.
            await self.abort(request_id)
            raise
1163

Antoni Baum's avatar
Antoni Baum committed
1164
1165
    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if not self.is_running:
            # Reuse the shared property instead of repeating its message.
            raise self.dead_error

        return self._abort(request_id)
1181

Antoni Baum's avatar
Antoni Baum committed
1182
    def _abort(self, request_id: str) -> None:
        """Ask the request tracker to abort `request_id`.

        Unlike `abort`, this does not check whether the background loop is
        running. The tracker is given `asyncio.CancelledError` as the
        exception for the aborted request.

        Args:
            request_id: The unique id of the request.
        """
        self._request_tracker.abort_request(
            request_id,
            exception=asyncio.CancelledError,
            verbose=self.log_requests,
        )
1194

1195
1196
1197
1198
    async def get_vllm_config(self) -> VllmConfig:
        """Return the wrapped engine's vLLM configuration."""
        vllm_config = self.engine.get_vllm_config()
        return vllm_config

1199
1200
    async def get_model_config(self) -> ModelConfig:
        """Return the wrapped engine's model configuration."""
        model_config = self.engine.get_model_config()
        return model_config
1202

1203
1204
    async def get_parallel_config(self) -> ParallelConfig:
        """Return the wrapped engine's parallel configuration."""
        parallel_config = self.engine.get_parallel_config()
        return parallel_config
1206

1207
1208
    async def get_decoding_config(self) -> DecodingConfig:
        """Return the wrapped engine's decoding configuration."""
        decoding_config = self.engine.get_decoding_config()
        return decoding_config
1210

1211
1212
    async def get_scheduler_config(self) -> SchedulerConfig:
        """Return the wrapped engine's scheduler configuration."""
        scheduler_config = self.engine.get_scheduler_config()
        return scheduler_config
1214
1215
1216

    async def get_lora_config(self) -> LoRAConfig:
        """Return the wrapped engine's LoRA configuration."""
        lora_config = self.engine.get_lora_config()
        return lora_config
1218

1219
1220
1221
1222
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Trigger stats logging on the wrapped engine.

        Both parameters are accepted but are not forwarded to the engine.
        """
        self.engine.do_log_stats()
1224

1225
    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy.

        Raises:
            AsyncEngineDeadError: If the background loop is stopped.
        """
        started_at = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        await self.engine.check_health_async()
        elapsed = time.perf_counter() - started_at
        logger.debug("Health check took %fs", elapsed)
1234
1235

    async def is_tracing_enabled(self) -> bool:
        """Return what the wrapped engine reports for `is_tracing_enabled`."""
        enabled = self.engine.is_tracing_enabled()
        return enabled
1237
1238

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register `logger` under `logger_name` on the wrapped engine."""
        self.engine.add_logger(logger=logger, logger_name=logger_name)
1240
1241

    def remove_logger(self, logger_name: str) -> None:
        """Remove the logger named `logger_name` from the wrapped engine."""
        self.engine.remove_logger(logger_name=logger_name)
1243
1244

    async def start_profile(self) -> None:
        """Delegate to the wrapped engine's `start_profile`."""
        self.engine.start_profile()
1246
1247

    async def stop_profile(self) -> None:
        """Delegate to the wrapped engine's `stop_profile`."""
        self.engine.stop_profile()
1249

1250
1251
1252
    async def reset_mm_cache(self) -> None:
        """Delegate to the wrapped engine's `reset_mm_cache`."""
        self.engine.reset_mm_cache()

1253
1254
1255
    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        """Delegate to the wrapped engine's `reset_prefix_cache`.

        Args:
            device: Optional device forwarded unchanged to the engine.
        """
        self.engine.reset_prefix_cache(device)
1256

1257
1258
1259
    async def sleep(self, level: int = 1) -> None:
        """Delegate to the wrapped engine's `sleep` with the given level."""
        self.engine.sleep(level)

1260
1261
    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Delegate to the wrapped engine's `wake_up`, forwarding `tags`."""
        self.engine.wake_up(tags)
1262

1263
1264
1265
    async def is_sleeping(self) -> bool:
        """Return what the wrapped engine reports for `is_sleeping`."""
        sleeping = self.engine.is_sleeping()
        return sleeping

1266
1267
1268
    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Delegate to the wrapped engine's `add_lora`."""
        self.engine.add_lora(lora_request)

1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
    async def collective_rpc(self,
                             method: str,
                             timeout: Optional[float] = None,
                             args: tuple = (),
                             kwargs: Optional[dict] = None):
        """
        Perform a collective RPC call through the wrapped engine.

        All arguments are forwarded positionally to the engine's
        `collective_rpc_async`, whose result is returned.
        """
        rpc_result = await self.engine.collective_rpc_async(
            method, timeout, args, kwargs)
        return rpc_result

1280
1281

# TODO(v1): Remove this class proxy when V1 goes default.
# Only when VLLM_USE_V1 is explicitly set and truthy is the module-level
# name `AsyncLLMEngine` rebound to the V1 `AsyncLLM` class.
if envs.is_set("VLLM_USE_V1"):
    if envs.VLLM_USE_V1:
        from vllm.v1.engine.async_llm import AsyncLLM

        AsyncLLMEngine = AsyncLLM  # type: ignore