async_llm_engine.py 50 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import asyncio
4
import copy
5
import time
6
import weakref
Antoni Baum's avatar
Antoni Baum committed
7
from functools import partial
8
9
from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
                    List, Mapping, Optional, Set, Tuple, Type, Union, overload)
10
from weakref import ReferenceType
11

12
13
from typing_extensions import deprecated

14
import vllm.envs as envs
15
16
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, VllmConfig)
17
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
18
from vllm.engine.arg_utils import AsyncEngineArgs
19
from vllm.engine.async_timeout import asyncio_timeout
20
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
21
from vllm.engine.metrics_types import StatLoggerBase
22
from vllm.engine.protocol import EngineClient
23
from vllm.executor.executor_base import ExecutorBase
24
from vllm.inputs import PromptType
25
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
26
from vllm.logger import init_logger
27
from vllm.lora.request import LoRARequest
28
29
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
30
from vllm.model_executor.layers.sampler import SamplerOutput
31
from vllm.outputs import PoolingRequestOutput, RequestOutput
32
from vllm.pooling_params import PoolingParams
33
from vllm.prompt_adapter.request import PromptAdapterRequest
34
from vllm.sampling_params import SamplingParams
35
from vllm.sequence import ExecuteModelRequest
36
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
37
from vllm.usage.usage_lib import UsageContext
38
from vllm.utils import Device, deprecate_kwargs, weak_bind
39
40

logger = init_logger(__name__)

# Per-iteration timeout (in seconds) for the engine, read once at import time
# from envs.VLLM_ENGINE_ITERATION_TIMEOUT_S.
# NOTE(review): presumably applied via `asyncio_timeout` in the engine loop
# further down this file — confirm.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S


class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine background task terminates unexpectedly.

    See `_log_task_completion`, which raises it for any outcome of the
    engine-loop task other than cancellation.
    """


def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """This function is only intended for the `engine.run_engine_loop()` task.

    In particular, that task runs a `while True` loop that can only exit if
    there is an exception, so a normal return is treated as a failure too.

    Args:
        task: The completed engine-loop task.
        error_callback: Called with the exception whenever the task did not
            end by cancellation (including the "returned normally" case).

    Raises:
        AsyncEngineDeadError: If the task failed or returned normally.
    """
    try:
        return_value = task.result()
        # Reaching here means the loop returned, which must never happen.
        # This AssertionError is deliberately caught by the generic handler
        # below so that it is logged and reported like any other failure.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {return_value}")
    except asyncio.CancelledError:
        # We assume that if the task is cancelled, we are gracefully shutting
        # down. This should only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on GitHub. See stack trace above for the "
            "actual cause.") from e


# Sentinel queued by AsyncStream.finish() to mark a normal end of stream.
# It is an Exception instance so that AsyncStream._is_raisable() accepts it;
# AsyncStream.generator() looks for this exact object and returns instead of
# raising it.
STOP_ITERATION = Exception()  # Sentinel


class AsyncStream:
    """A stream of RequestOutputs or PoolingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator.

    Producers call `put()` for each output and `finish()` to end the stream;
    the consumer iterates `generator()`. Items are buffered in an unbounded
    `asyncio.Queue`.
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        # Called with request_id if the consumer closes the generator early.
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, PoolingRequestOutput,
                              Exception]) -> None:
        """Queue an output (or exception); ignored once finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Mark the stream finished. Later calls are no-ops.

        If `exception` is an exception instance or class, it is queued so the
        consumer raises it. Otherwise the STOP_ITERATION sentinel is queued
        to end iteration normally.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        """Whether `finish()` has been called."""
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        """Yield queued outputs until the stream ends.

        Raises any queued exception (instance or class). If the consumer
        closes the generator early, the `cancel` callback is invoked with
        this request's id and `asyncio.CancelledError` is raised.
        """
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    # Identity check: STOP_ITERATION is a unique sentinel
                    # object, and `==` would defer to a queued exception's
                    # own __eq__ if it overrides it.
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """Return True if `value` is an exception instance or class."""
        return isinstance(value, BaseException) or (
            isinstance(value, type) and issubclass(value, BaseException))


class RequestTracker:
    """Synchronous abstraction for tracking requests.

    New requests and aborts are buffered in asyncio queues and drained in
    bulk by `get_new_and_aborted_requests()`. Outputs coming back from the
    engine are routed to each request's `AsyncStream`.
    """

    def __init__(self) -> None:
        # Streams of requests already handed to the engine, keyed by id.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Request ids to abort; drained by get_new_and_aborted_requests().
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, engine add_request kwargs) pairs not yet handed over.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set by add_request(); awaited in wait_for_new_requests().
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item) -> bool:
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None)."""
        if request_id is not None:
            self.abort_request(request_id, exception=exc)
        else:
            # NB: tuple() used here because self.abort_request pops the stream
            # out of self._request_streams, so we can't iterate on it directly
            for rid in tuple(self._request_streams.keys()):
                self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     PoolingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        Forwards the output to the request's stream. A finished output also
        finishes the stream and stops tracking the request.
        """
        request_id = request_output.request_id
        finished = request_output.finished

        if finished:
            stream = self._request_streams.pop(request_id, None)
        else:
            stream = self._request_streams.get(request_id)
        # Guard against a KeyError which can occur if the request was aborted
        # while the output was generated
        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if verbose and finished:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine to the request's stream."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Returns:
            The stream the caller iterates to receive outputs.

        Raises:
            KeyError: If a request with this id is already being tracked.
        """
        # NOTE(review): only ids already handed to the engine are checked
        # here. A duplicate id still waiting in _new_requests is not
        # detected. Confirm that callers guarantee unique ids.
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        # If the consumer abandons the stream, abort the request.
        abort_request = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, abort_request)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Abort a request during next background loop iteration.

        The engine-side abort is deferred, but a tracked stream is finished
        immediately, with `exception` if one is given.
        """
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        Returns:
            A tuple of (kwargs dicts for the engine's add-request call, ids
            of requests to abort in the engine). A request that was aborted
            before it was handed over has its stream finished with
            `asyncio.CancelledError` and appears in neither collection.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._aborted_requests.empty():
            request_id = self._aborted_requests.get_nowait()
            finished_requests.add(request_id)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            request_id = stream.request_id
            if request_id in finished_requests:
                # The request has already been aborted.
                stream.finish(asyncio.CancelledError)
                finished_requests.discard(request_id)
            else:
                self._request_streams[request_id] = stream
                new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self) -> None:
        """Return at once if a new request is queued, else wait for one.

        `new_requests_event` is cleared before returning.
        """
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self) -> bool:
        """Whether any added request has not yet been handed over."""
        return not self._new_requests.empty()


class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are ran asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # Bound on every path. It is only refreshed from the scheduler when a
        # new non-empty schedule is produced below.
        finished_requests_ids: List[str] = []

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):

            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            if not scheduler_outputs.is_empty():
                # this will cause mamba_cache/minimax_cache failed
                # to release finished_requests_ids of the last steps
                finished_requests_ids = self.scheduler[
                    virtual_engine].get_and_reset_finished_requests_ids()

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # Clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            if outputs and allow_async_output_proc:
                assert len(
                    outputs
                ) == 1, "Async postprocessor expects only a single output set"
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)

        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def get_tokenizer_async(self,
                                  lora_request: Optional[LoRARequest] = None
                                  ) -> AnyTokenizer:
        """Return the tokenizer for `lora_request` (or the base tokenizer)."""
        return await (
            self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))

    @overload
    @deprecated("'inputs' will be renamed to 'prompt'")
    async def add_request_async(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @overload
    async def add_request_async(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request_async(
            self,
            request_id: str,
            prompt: Optional[PromptType] = None,
            params: Optional[Union[SamplingParams, PoolingParams]] = None,
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
            *,
            inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        """Async version of {meth}`add_request`.

        Raises:
            ValueError: If `lora_request` is given but LoRA is not enabled,
                or `priority` is non-zero without priority scheduling.
        """
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")
        if arrival_time is None:
            arrival_time = time.time()

        if (isinstance(prompt, dict)
                and prompt.get("prompt_embeds", None) is not None
                and not prompt.get("prompt_token_ids", None)):
            # Work on a shallow copy so the caller's prompt dict is not
            # mutated (it may be reused for other requests).
            prompt = copy.copy(prompt)
            # We use the -2 dimension (instead of 0) in case a batched input
            # of batch size 1 is passed in.
            prompt["prompt_token_ids"] = [0
                                          ] * prompt["prompt_embeds"].shape[-2]

        processed_inputs = await self.input_preprocessor.preprocess_async(
            prompt,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        if isinstance(params, SamplingParams) and \
            params.guided_decoding is not None:
            # Guided decoding has an async implementation for building logits
            # processors in a separate threadpool.
            # We want to invoke that here instead of using the blocking
            # implementation in the LLMEngine
            params = await build_guided_decoding_logits_processor_async(
                sampling_params=params,
                tokenizer=await self.get_tokenizer_async(lora_request),
                default_guided_backend=self.decoding_config.
                guided_decoding_backend,
                reasoning_backend=self.decoding_config.reasoning_backend,
                model_config=self.model_config)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    async def check_health_async(self) -> None:
        """Run the executor's (synchronous) health check."""
        self.model_executor.check_health()

    async def collective_rpc_async(self,
                                   method: str,
                                   timeout: Optional[float] = None,
                                   args: tuple = (),
                                   kwargs: Optional[dict] = None):
        """Not supported by this engine; always raises NotImplementedError."""
        raise NotImplementedError


async def build_guided_decoding_logits_processor_async(
        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
        default_guided_backend: str, reasoning_backend: Optional[str],
        model_config: ModelConfig) -> SamplingParams:
    """Build the guided-decoding logits processor for a request.

    If `sampling_params.guided_decoding` is None, `sampling_params` is
    returned unchanged. Otherwise a shallow copy is returned in which:

    - `logits_processors` additionally holds the processor built from
      `guided_decoding` (when one is produced), and
    - `guided_decoding` is set to None.

    The caller's `sampling_params`, its `logits_processors` list and its
    `guided_decoding` object are not modified.

    Args:
        sampling_params: The request's sampling parameters.
        tokenizer: Tokenizer passed to the processor builder.
        default_guided_backend: Backend used when
            `guided_decoding.backend` is not set.
        reasoning_backend: Optional reasoning backend passed through.
        model_config: Model config passed through.
    """
    if sampling_params.guided_decoding is None:
        return sampling_params

    # Defensively copy sampling params since guided decoding logits
    # processors can have different state for each request. The copy is
    # shallow, so the nested objects modified below (guided_decoding and the
    # logits_processors list) are copied/rebuilt rather than mutated.
    sampling_params = copy.copy(sampling_params)
    guided_decoding = copy.copy(sampling_params.guided_decoding)

    logger.debug(
        "Building guided decoding logits processor. "
        "guided_decoding: %s%s", guided_decoding,
        f", reasoning_backend: {reasoning_backend}"
        if reasoning_backend is not None else "")

    guided_decoding.backend = guided_decoding.backend or default_guided_backend

    processor = await get_guided_decoding_logits_processor(
        guided_params=guided_decoding,
        tokenizer=tokenizer,
        reasoning_backend=reasoning_backend,
        model_config=model_config)

    if processor:
        # Build a new list: appending in place would also add this request's
        # processor to the list shared with the caller's params.
        sampling_params.logits_processors = [
            *(sampling_params.logits_processors or []), processor
        ]

    # Unset guided decoding params after constructing the lp from them
    sampling_params.guided_decoding = None

    return sampling_params


class AsyncLLMEngine(EngineClient):
    """An asynchronous wrapper for {class}`LLMEngine`.

    This class is used to wrap the {class}`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The {class}`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the {class}`LLMEngine` to the caller.

    Args:
        log_requests: Whether to log the requests.
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
        *args: Arguments for {class}`LLMEngine`.
        **kwargs: Arguments for {class}`LLMEngine`.
    """

    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

    def __init__(self,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        if envs.VLLM_USE_V1:
            raise ValueError(
                "Using V0 AsyncLLMEngine, but envs.VLLM_USE_V1=True. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        self.log_requests = log_requests
        self.engine = self._engine_class(*args, **kwargs)

        # This ensures quick processing of request outputs
        # so the append to asyncio queues is not delayed,
        # especially for multi-step.
        self.use_process_request_outputs_callback = (
            self.engine.model_config.use_async_output_proc)

        if self.use_process_request_outputs_callback:
            # weak_bind so the wrapped engine does not keep `self` alive.
            self.engine.process_request_outputs_callback = \
                weak_bind(self.process_request_outputs)

        self.background_loop: Optional[asyncio.Future] = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Lazy initialized fields (created in start_background_loop()).
        self._request_tracker: RequestTracker

    def __del__(self):
        # Wake up the engine loop so that it notices this engine has been
        # garbage collected and exits cleanly. The tracker only exists once
        # start_background_loop() has run, hence the getattr default.
        if rt := getattr(self, "_request_tracker", None):
            rt.new_requests_event.set()

    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Resolve the executor class; delegates to LLMEngine."""
        return LLMEngine._get_executor_cls(engine_config)

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
        disable_log_requests: bool = False,
        disable_log_stats: bool = False,
    ) -> "AsyncLLMEngine":
        """Create an AsyncLLMEngine from a VllmConfig."""

        return cls(
            vllm_config=vllm_config,
            executor_class=cls._get_executor_cls(vllm_config),
            start_engine_loop=start_engine_loop,
            log_requests=not disable_log_requests,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        When ``envs.VLLM_USE_V1`` is set, the V1 ``AsyncLLM`` class is
        constructed instead of this class.
        """

        vllm_config = engine_args.create_engine_config(usage_context)

        async_engine_cls = cls
        if envs.VLLM_USE_V1:
            from vllm.v1.engine.async_llm import AsyncLLM as V1AsyncLLMEngine
            async_engine_cls = V1AsyncLLMEngine

        return async_engine_cls.from_vllm_config(
            vllm_config=vllm_config,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            disable_log_stats=engine_args.disable_log_stats,
            disable_log_requests=engine_args.disable_log_requests,
        )

    @property
    def is_running(self) -> bool:
        """True while the background loop task exists and is not done."""
        return (self.background_loop is not None
                and self._background_loop_unshielded is not None
                and not self._background_loop_unshielded.done())

    @property
    def is_stopped(self) -> bool:
        """True if the engine errored or its background loop task finished."""
        return self.errored or (self.background_loop is not None and
                                self._background_loop_unshielded is not None
                                and self._background_loop_unshielded.done())

    @property
    def errored(self) -> bool:
        """True once set_errored() has recorded an exception."""
        return self._errored_with is not None

    @property
    def dead_error(self) -> BaseException:
        """A fresh exception describing that the background loop is dead."""
        return AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")

    def set_errored(self, exc: Exception) -> None:
        """Record `exc` as the reason the engine errored."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        # Invoked when the background loop task completes with an error:
        # mark the engine errored and fail all tracked request streams.
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)

    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Return the input preprocessor of the wrapped engine."""
        return self.engine.input_preprocessor

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer (optionally for a LoRA request)."""
        return await self.engine.get_tokenizer_async(lora_request)

    def start_background_loop(self) -> None:
        """Start the background loop."""
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        # The loop only receives a weakref to `self` so that it does not
        # keep the engine alive on its own.
        self._background_loop_unshielded = asyncio.get_event_loop(
        ).create_task(self.run_engine_loop(weakref.ref(self)))
        self._background_loop_unshielded.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        self.background_loop = asyncio.shield(self._background_loop_unshielded)

    def shutdown_background_loop(self) -> None:
        """
        Shut down the background loop.

        This method needs to be called during cleanup to remove
        references to `self` and properly GC the resources held
        by the async LLM engine (e.g., the executors as well as
        their resources).
        """
        if self._background_loop_unshielded is not None:
            self._background_loop_unshielded.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Returns True if there are in-progress requests."""

        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())

        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            try:
                await self.engine.add_request_async(**new_request)
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    new_request["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if aborted_requests:
            await self._engine_abort(aborted_requests)

        request_outputs = await self.engine.step_async(virtual_engine)

        # Put the outputs into the corresponding streams.
        # If used as a callback, then already invoked inside
        # LLMEngine's _process_model_outputs
        if not self.use_process_request_outputs_callback:
            all_finished = self.process_request_outputs(request_outputs)
        else:
            # For callback case, we only need to detect when all
            # requests are finished
            all_finished = all(request_output.finished
                               for request_output in request_outputs)

        return not all_finished

    def process_request_outputs(self, request_outputs) -> bool:
        """Route outputs to their request streams.

        Returns True if every output in `request_outputs` is finished.
        """
        # Put the outputs into the corresponding streams.
        all_finished = True
        for request_output in request_outputs:
            self._request_tracker.process_request_output(
                request_output, verbose=self.log_requests)
            all_finished = all_finished and request_output.finished

        return all_finished

    async def _engine_abort(self, request_ids: Iterable[str]):
        self.engine.abort_request(request_ids)

    @staticmethod
    async def run_engine_loop(engine_ref: ReferenceType):
        """We use a weakref to the engine so that the running loop
        doesn't prevent the engine being garbage collected."""
        engine: Optional[AsyncLLMEngine] = engine_ref()
        if not engine:
            return

        pipeline_parallel_size = \
                engine.engine.parallel_config.pipeline_parallel_size
        # One engine_step task slot per virtual engine (pipeline stage).
        has_requests_in_progress = [False] * pipeline_parallel_size
        while True:
            if not any(has_requests_in_progress):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                await engine.engine.stop_remote_worker_execution_loop_async()
                request_tracker = engine._request_tracker
                # Allow engine to be garbage collected while
                # waiting for new requests
                del engine
                await asyncio.sleep(0)
                if engine_ref() is None:
                    return
                await request_tracker.wait_for_new_requests()
                engine = engine_ref()
                if not engine:
                    return
                logger.debug("Got new requests!")
                requests_in_progress = [
                    asyncio.create_task(engine.engine_step(ve))
                    for ve in range(pipeline_parallel_size)
                ]
                has_requests_in_progress = [True] * pipeline_parallel_size

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        requests_in_progress,
                        return_when=asyncio.FIRST_COMPLETED)
                    for _ in range(pipeline_parallel_size):
                        await asyncio.sleep(0)
                for task in done:
                    result = task.result()
                    virtual_engine = requests_in_progress.index(task)
                    has_unfinished_requests = (
                        engine.engine.
                        has_unfinished_requests_for_virtual_engine(
                            virtual_engine))
                    if result or has_unfinished_requests:
                        # Reschedule a step for this virtual engine.
                        requests_in_progress[virtual_engine] = (
                            asyncio.create_task(
                                engine.engine_step(virtual_engine)))
                        has_requests_in_progress[virtual_engine] = True
                    else:
                        has_requests_in_progress[virtual_engine] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                engine.set_errored(exc)
                raise
            await asyncio.sleep(0)

    # This method does not need to be async, but kept that way
    # for backwards compatibility.
    @overload
    @deprecated("'inputs' will be renamed to 'prompt'")
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, PoolingRequestOutput], None]]:
        ...

    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, PoolingRequestOutput], None]]:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request(
        self,
        request_id: str,
        prompt: Optional[PromptType] = None,
        params: Optional[Union[SamplingParams, PoolingParams]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        """Register a request with the request tracker.

        Starts the background loop if it is not running and
        `start_engine_loop` is set.

        Returns:
            The output generator of the request's stream.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                may not be started automatically.
            ValueError: If `priority` is non-zero but the scheduler policy
                is not "priority".
        """
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise self.dead_error

        if (priority != 0
                and self.engine.scheduler_config.policy != "priority"):
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            prompt=prompt,
            params=params,
            arrival_time=arrival_time or time.time(),
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

        return stream.generator()

    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See {class}`~vllm.inputs.PromptType`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use
                                            for generation, if any.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              {meth}`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> # note that engine_args here is AsyncEngineArgs instance
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        try:
            async for output in await self.add_request(
                    request_id,
                    prompt,
                    sampling_params,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    prompt_adapter_request=prompt_adapter_request,
                    priority=priority,
            ):
                yield LLMEngine.validate_output(output, RequestOutput)
        except asyncio.CancelledError:
            # The consumer went away: abort the request in the engine too.
            await self.abort(request_id)
            raise

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See {class}`~vllm.inputs.PromptType`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Yields:
            The output `PoolingRequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              {meth}`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
        ```
        # Please refer to entrypoints/api_server.py for
        # the complete example.

        # initialize the engine and the example input
        # note that engine_args here is AsyncEngineArgs instance
        engine = AsyncLLMEngine.from_engine_args(engine_args)
        example_input = {
            "input": "What is LLM?",
            "request_id": "0",
        }

        # start the generation
        results_generator = engine.encode(
            example_input["input"],
            PoolingParams(),
            example_input["request_id"])

        # get the results
        final_output = None
        async for request_output in results_generator:
            if await request.is_disconnected():
                # Abort the request if the client disconnects.
                await engine.abort(example_input["request_id"])
                # Return or raise an error
                ...
            final_output = request_output

        # Process and return the final output
        ...
        ```
        """
        try:
            async for output in await self.add_request(
                    request_id,
                    prompt,
                    pooling_params,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=priority,
            ):
                yield LLMEngine.validate_output(output, PoolingRequestOutput)
        except asyncio.CancelledError:
            # The consumer went away: abort the request in the engine too.
            await self.abort(request_id)
            raise

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if not self.is_running:
            raise self.dead_error

        return self._abort(request_id)

    def _abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
        self._request_tracker.abort_request(request_id,
                                            exception=asyncio.CancelledError,
                                            verbose=self.log_requests)

    async def get_vllm_config(self) -> VllmConfig:
        """Get the vllm configuration of the vLLM engine."""
        return self.engine.get_vllm_config()

    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        return self.engine.get_model_config()

    async def get_parallel_config(self) -> ParallelConfig:
        """Get the parallel configuration of the vLLM engine."""
        return self.engine.get_parallel_config()

    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        return self.engine.get_decoding_config()

    async def get_scheduler_config(self) -> SchedulerConfig:
        """Get the scheduling configuration of the vLLM engine."""
        return self.engine.get_scheduler_config()

    async def get_lora_config(self) -> LoRAConfig:
        """Get the lora configuration of the vLLM engine."""
        return self.engine.get_lora_config()

    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Log engine stats; the arguments are accepted but not used."""
        self.engine.do_log_stats()

    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        t = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        await self.engine.check_health_async()
        logger.debug("Health check took %fs", time.perf_counter() - t)

    async def is_tracing_enabled(self) -> bool:
        """Return whether tracing is enabled on the wrapped engine."""
        return self.engine.is_tracing_enabled()

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger on the wrapped engine."""
        self.engine.add_logger(logger_name=logger_name, logger=logger)

    def remove_logger(self, logger_name: str) -> None:
        """Remove a stat logger from the wrapped engine."""
        self.engine.remove_logger(logger_name=logger_name)

    async def start_profile(self) -> None:
        """Forward to the wrapped engine's start_profile()."""
        self.engine.start_profile()

    async def stop_profile(self) -> None:
        """Forward to the wrapped engine's stop_profile()."""
        self.engine.stop_profile()

    async def reset_mm_cache(self) -> None:
        """Forward to the wrapped engine's reset_mm_cache()."""
        self.engine.reset_mm_cache()

    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        """Forward to the wrapped engine's reset_prefix_cache()."""
        self.engine.reset_prefix_cache(device)

    async def sleep(self, level: int = 1) -> None:
        """Forward to the wrapped engine's sleep()."""
        self.engine.sleep(level)

    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Forward to the wrapped engine's wake_up()."""
        self.engine.wake_up(tags)

    async def is_sleeping(self) -> bool:
        """Return the wrapped engine's is_sleeping() result."""
        return self.engine.is_sleeping()

    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Forward to the wrapped engine's add_lora()."""
        self.engine.add_lora(lora_request)

    async def collective_rpc(self,
                             method: str,
                             timeout: Optional[float] = None,
                             args: tuple = (),
                             kwargs: Optional[dict] = None):
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine.collective_rpc_async(method, timeout, args,
                                                      kwargs)

1265
1266

# TODO(v1): Remove this class proxy when V1 goes default.
# When VLLM_USE_V1 is explicitly set and truthy, rebind the module-level
# name ``AsyncLLMEngine`` to the V1 ``AsyncLLM`` class, so code importing
# ``AsyncLLMEngine`` from this module receives the V1 engine instead of the
# V0 class defined above.
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
    from vllm.v1.engine.async_llm import AsyncLLM

    AsyncLLMEngine = AsyncLLM  # type: ignore