async_llm_engine.py 49 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import asyncio
4
import copy
5
import time
6
import weakref
Antoni Baum's avatar
Antoni Baum committed
7
from functools import partial
8
9
from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
                    List, Mapping, Optional, Set, Tuple, Type, Union, overload)
10
from weakref import ReferenceType
11

12
13
from typing_extensions import deprecated

14
import vllm.envs as envs
15
16
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, VllmConfig)
17
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
18
from vllm.engine.arg_utils import AsyncEngineArgs
19
from vllm.engine.async_timeout import asyncio_timeout
20
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
21
from vllm.engine.metrics_types import StatLoggerBase
22
from vllm.engine.protocol import EngineClient
23
from vllm.executor.executor_base import ExecutorBase
24
from vllm.inputs import PromptType
25
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
26
from vllm.logger import init_logger
27
from vllm.lora.request import LoRARequest
28
29
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
30
from vllm.model_executor.layers.sampler import SamplerOutput
31
from vllm.outputs import PoolingRequestOutput, RequestOutput
32
from vllm.pooling_params import PoolingParams
33
from vllm.prompt_adapter.request import PromptAdapterRequest
34
from vllm.sampling_params import SamplingParams
35
from vllm.sequence import ExecuteModelRequest
36
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
37
from vllm.usage.usage_lib import UsageContext
38
from vllm.utils import Device, deprecate_kwargs, weak_bind
39
40

# Module-level logger named after this module; used by the request tracker,
# the background-task completion handler and guided-decoding setup below.
logger = init_logger(__name__)
41
# Snapshot of envs.VLLM_ENGINE_ITERATION_TIMEOUT_S taken when this module is
# imported. NOTE(review): presumably a per-iteration timeout in seconds (per
# the `_S` suffix); it is not consumed in the visible part of this file —
# confirm against the engine loop that uses it.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
42

Antoni Baum's avatar
Antoni Baum committed
43

44
45
46
47
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine's background task has stopped unexpectedly."""


48
49
50
51
52
53
54
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Report how the ``engine.run_engine_loop()`` task ended.

    This function is only intended for that task: it runs a ``while True``
    loop, so it can only legitimately end by cancellation or by raising.

    * Cancelled: treated as a graceful shutdown and only logged.
    * Raised (or, unexpectedly, returned a value): the error is logged,
      passed to ``error_callback`` and re-raised as
      :class:`AsyncEngineDeadError`.

    Args:
        task: The finished engine-loop task.
        error_callback: Called with the exception that ended the task. A
            normal return is reported as an ``AssertionError``.

    Raises:
        AsyncEngineDeadError: For every outcome other than cancellation.
    """
    try:
        loop_result = task.result()
        # A normal return is itself a failure. Raise here so it goes through
        # the same ``except Exception`` handling as a real crash.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {loop_result}")
    except asyncio.exceptions.CancelledError:
        # Cancellation is assumed to be a graceful shutdown (program exit).
        logger.info("Engine is gracefully shutting down.")
    except Exception as failure:
        logger.error("Engine background task failed", exc_info=failure)
        error_callback(failure)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on GitHub. See stack trace above for the "
            "actual cause.") from failure
74
75


76
77
78
# Sentinel queued by AsyncStream.finish() when a stream ends without an
# exception; AsyncStream.generator() returns (instead of raising) when it
# dequeues this exact object. It must be an exception instance because the
# generator only compares against it after AsyncStream._is_raisable() passes.
STOP_ITERATION = Exception()  # Sentinel


Antoni Baum's avatar
Antoni Baum committed
79
class AsyncStream:
    """Per-request channel of RequestOutputs / PoolingRequestOutputs.

    Producers push items with :meth:`put` and close the channel with
    :meth:`finish`. A single consumer drains it through the async generator
    returned by :meth:`generator`.
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        # Invoked with request_id if the consumer closes the generator early.
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, PoolingRequestOutput,
                              Exception]) -> None:
        """Enqueue ``item``; ignored once the stream has been finished."""
        if self._finished:
            return
        self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Close the stream (idempotent).

        Enqueues ``exception`` if it is an exception instance or class, so
        the consumer raises it. Otherwise enqueues the STOP_ITERATION
        sentinel, so the consumer ends normally.
        """
        if self._finished:
            return
        self._finished = True
        terminal = (exception
                    if self._is_raisable(exception) else STOP_ITERATION)
        self._queue.put_nowait(terminal)

    @property
    def finished(self) -> bool:
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        """Yield queued outputs until the stream is finished.

        Queued exceptions are raised, except the STOP_ITERATION sentinel,
        which ends iteration. If the consumer closes the generator early, the
        cancel callback is invoked with this request's id and
        ``asyncio.CancelledError`` is raised in place of ``GeneratorExit``.
        """
        try:
            while True:
                item = await self._queue.get()
                if not self._is_raisable(item):
                    yield item
                    continue
                if item == STOP_ITERATION:
                    return
                raise item
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any):
        """True if ``value`` is an exception instance or exception class."""
        if isinstance(value, BaseException):
            return True
        return isinstance(value, type) and issubclass(value, BaseException)

Antoni Baum's avatar
Antoni Baum committed
128

129
130
131
132
133
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    New requests and abort notices are buffered in queues and handed to the
    engine in batches by :meth:`get_new_and_aborted_requests`. Every request
    already handed over is mapped to the :class:`AsyncStream` its caller
    reads from.
    """

    def __init__(self) -> None:
        # request_id -> stream, for requests already handed to the engine.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Ids to abort on the next background-loop iteration.
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, engine add_request kwargs) pairs not yet handed over.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set by add_request; awaited/cleared by wait_for_new_requests.
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Abort one request (or every tracked one if ``request_id`` is None),
        delivering ``exc`` to the affected stream(s)."""
        if request_id is None:
            # Snapshot the ids first: abort_request pops entries out of
            # self._request_streams, so it cannot be iterated directly.
            targets = list(self._request_streams)
        else:
            targets = [request_id]
        for rid in targets:
            self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     PoolingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Route one engine output to its request's stream.

        A finished request is dropped from tracking, and its stream is closed
        after the final output is delivered. Outputs for untracked requests
        (e.g. aborted while the output was being generated) are discarded.
        """
        rid = request_output.request_id
        done = request_output.finished

        streams = self._request_streams
        stream = streams.pop(rid, None) if done else streams.get(rid)

        if stream is not None:
            stream.put(request_output)
            if done:
                stream.finish()

        if done and verbose:
            logger.info("Finished request %s.", rid)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Fail a request with an exception raised by the engine."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Queue a request for the next background-loop iteration.

        Returns the stream the caller should read outputs from.

        Raises:
            KeyError: If ``request_id`` is already being tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        on_cancel = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, on_cancel)
        payload = {"request_id": request_id, **engine_add_request_kwargs}
        self._new_requests.put_nowait((stream, payload))

        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Schedule ``request_id`` for abort on the next loop iteration and
        close its stream (with ``exception``, if given) when it is tracked."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is None:
            return
        stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Drain both queues.

        Returns ``(new_request_kwargs, aborted_ids)``. A request aborted
        before it was ever handed over has its stream cancelled and appears
        in neither collection.
        """
        aborted: Set[str] = set()
        while not self._aborted_requests.empty():
            aborted.add(self._aborted_requests.get_nowait())

        to_add: List[Dict] = []
        while not self._new_requests.empty():
            stream, payload = self._new_requests.get_nowait()
            rid = stream.request_id
            if rid not in aborted:
                self._request_streams[rid] = stream
                to_add.append(payload)
                continue
            # The request has already been aborted.
            stream.finish(asyncio.CancelledError)
            aborted.discard(rid)

        return to_add, aborted

    async def wait_for_new_requests(self):
        """Wait on ``new_requests_event`` unless requests are already queued,
        then clear the event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        return not self._new_requests.empty()
261

Antoni Baum's avatar
Antoni Baum committed
262
263
264
265

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods.

    Provides coroutine versions of stepping (``step_async``), request
    submission (``add_request_async``) and health checking
    (``check_health_async``). Model execution, input preprocessing, tokenizer
    lookup and guided-decoding processor construction are awaited rather
    than called synchronously.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.

        Args:
            virtual_engine: Index selecting which scheduler, cached scheduler
                outputs and scheduler context to step.

        Returns:
            The ``request_outputs`` list of that virtual engine's scheduler
            context. It is cleared at the start of every call. In the
            multi-step case it is returned early while steps remain.
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):

            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)
        else:
            finished_requests_ids = []

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            # Nothing scheduled: flush any pending async outputs instead.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # Clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = (
                False if not seq_group_metadata_list else
                seq_group_metadata_list[0].state.num_steps == 1)

            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            if outputs and allow_async_output_proc:
                assert len(
                    outputs
                ) == 1, "Async postprocessor expects only a single output set"
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)

        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def get_tokenizer_async(self,
                                  lora_request: Optional[LoRARequest] = None
                                  ) -> AnyTokenizer:
        """Return the tokenizer the tokenizer group selects for
        ``lora_request``."""
        return await (
            self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))

    @overload
    @deprecated("'inputs' will be renamed to 'prompt'")
    async def add_request_async(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @overload
    async def add_request_async(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request_async(
            self,
            request_id: str,
            prompt: Optional[PromptType] = None,
            params: Optional[Union[SamplingParams, PoolingParams]] = None,
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
            *,
            inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        """Async version of :meth:`add_request`.

        Validates and preprocesses ``prompt``, builds guided-decoding logits
        processors off the event loop when requested, and hands the processed
        request to ``_add_processed_request``.

        Raises:
            ValueError: If ``lora_request`` is given but LoRA is not enabled,
                or a non-zero ``priority`` is given while the scheduler
                policy is not ``"priority"``.
        """
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")
        if arrival_time is None:
            arrival_time = time.time()

        # Looked up at most once and reused for guided decoding below.
        tokenizer: Optional[AnyTokenizer] = None
        if self.tokenizer is not None:
            tokenizer = await self.get_tokenizer_async(lora_request)
            self._validate_token_prompt(prompt, tokenizer=tokenizer)

        preprocessed_inputs = await self.input_preprocessor.preprocess_async(
            prompt,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        if isinstance(params, SamplingParams) and \
            params.guided_decoding is not None:
            # Guided decoding has an async implementation for building logits
            # processors in a separate threadpool.
            # We want to invoke that here instead of using the blocking
            # implementation in the LLMEngine
            if tokenizer is None:
                tokenizer = await self.get_tokenizer_async(lora_request)
            params = await build_guided_decoding_logits_processor_async(
                sampling_params=params,
                tokenizer=tokenizer,
                default_guided_backend=self.decoding_config.
                guided_decoding_backend,
                reasoning_backend=self.decoding_config.reasoning_backend,
                model_config=self.model_config)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    async def check_health_async(self) -> None:
        """Run the tokenizer's health check (if a tokenizer is loaded), then
        the model executor's."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
529

530

531
532
async def build_guided_decoding_logits_processor_async(
        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
        default_guided_backend: str, reasoning_backend: Optional[str],
        model_config: ModelConfig) -> SamplingParams:
    """Turn ``sampling_params.guided_decoding`` into a logits processor.

    If ``guided_decoding`` is ``None``, ``sampling_params`` is returned
    unchanged. Otherwise the function works on a copy:

    * a processor is built from (a copy of) ``guided_decoding``, with
      ``default_guided_backend`` filled in when no backend was requested;
    * the processor is appended to a new ``logits_processors`` list;
    * ``guided_decoding`` is cleared so it is not applied again.

    The caller's object, including its ``guided_decoding`` and
    ``logits_processors`` members, is never mutated. A single
    ``SamplingParams`` can therefore be reused across requests without one
    request's stateful processor leaking into another.

    Args:
        sampling_params: The request's sampling parameters.
        tokenizer: Passed through to the processor factory.
        default_guided_backend: Backend used when
            ``guided_decoding.backend`` is unset.
        reasoning_backend: Passed through to the processor factory.
        model_config: Passed through to the processor factory.

    Returns:
        The original ``sampling_params`` if no guided decoding was requested,
        otherwise the updated copy.
    """
    if sampling_params.guided_decoding is None:
        return sampling_params

    # Defensively copy sampling params since guided decoding logits
    # processors can have different state for each request. copy.copy is
    # shallow, so the nested objects modified below are copied separately,
    # keeping those changes out of the caller's object.
    sampling_params = copy.copy(sampling_params)
    guided_decoding = copy.copy(sampling_params.guided_decoding)

    logger.debug(
        "Building guided decoding logits processor. "
        "guided_decoding: %s%s", guided_decoding,
        f", reasoning_backend: {reasoning_backend}"
        if reasoning_backend is not None else "")

    guided_decoding.backend = guided_decoding.backend or default_guided_backend

    processor = await get_guided_decoding_logits_processor(
        guided_params=guided_decoding,
        tokenizer=tokenizer,
        reasoning_backend=reasoning_backend,
        model_config=model_config)

    if processor:
        # Build a fresh list instead of appending in place. After the shallow
        # copy, any existing list is still shared with the caller.
        sampling_params.logits_processors = [
            *(sampling_params.logits_processors or []), processor
        ]

    # Unset guided decoding params after constructing the lp from them
    sampling_params.guided_decoding = None

    return sampling_params


class AsyncLLMEngine(EngineClient):
    """An asynchronous wrapper for :class:`LLMEngine`.

    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.

    Args:
        log_requests: Whether to log the requests.
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
    """

    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

    def __init__(self,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        if envs.VLLM_USE_V1:
            raise ValueError(
                "Using V0 AsyncLLMEngine, but envs.VLLM_USE_V1=True. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        self.log_requests = log_requests
        self.engine = self._engine_class(*args, **kwargs)

        # This ensures quick processing of request outputs
        # so the append to asyncio queues is not delayed,
        # especially for multi-step.
        self.use_process_request_outputs_callback = (
            self.engine.model_config.use_async_output_proc)

        if self.use_process_request_outputs_callback:
            # Weakly bound so the wrapped engine does not keep `self` alive.
            self.engine.process_request_outputs_callback = \
                weak_bind(self.process_request_outputs)

        self.background_loop: Optional[asyncio.Future] = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Lazy initialized fields (set in start_background_loop)
        self._request_tracker: RequestTracker

    def __del__(self):
        """Best-effort wake-up of the engine loop when this object is freed.

        ``run_engine_loop`` only holds a weak reference to the engine. While
        it is parked in ``wait_for_new_requests`` it has to be woken so that
        it notices the engine is gone and returns.
        """
        # The tracker is stored as ``_request_tracker`` and only exists once
        # start_background_loop() has run. Use an explicit None check: the
        # tracker object may be falsy (e.g. if it defines __len__ and has no
        # streams) while the loop still needs waking.
        rt = getattr(self, "_request_tracker", None)
        if rt is not None:
            try:
                # Wake up engine loop so that it will exit cleanly
                rt.new_requests_event.set()
            except RuntimeError:
                # Waking waiters schedules callbacks on the event loop, which
                # fails if that loop is already closed (e.g. at interpreter
                # shutdown). There is nothing left to wake in that case.
                pass

    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Delegate executor-class selection to :class:`LLMEngine`."""
        return LLMEngine._get_executor_cls(engine_config)

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
        disable_log_requests: bool = False,
        disable_log_stats: bool = False,
    ) -> "AsyncLLMEngine":
        """Create an AsyncLLMEngine from the EngineArgs."""

        return cls(
            vllm_config=vllm_config,
            executor_class=cls._get_executor_cls(vllm_config),
            start_engine_loop=start_engine_loop,
            log_requests=not disable_log_requests,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        When ``envs.VLLM_USE_V1`` is set, the V1 ``AsyncLLM`` class is
        constructed instead of this class.
        """

        vllm_config = engine_args.create_engine_config(usage_context)

        async_engine_cls = cls
        if envs.VLLM_USE_V1:
            from vllm.v1.engine.async_llm import AsyncLLM as V1AsyncLLMEngine
            async_engine_cls = V1AsyncLLMEngine

        return async_engine_cls.from_vllm_config(
            vllm_config=vllm_config,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            disable_log_stats=engine_args.disable_log_stats,
            disable_log_requests=engine_args.disable_log_requests,
        )

    @property
    def is_running(self) -> bool:
        """True while the background loop task exists and is not done."""
        return (self.background_loop is not None
                and self._background_loop_unshielded is not None
                and not self._background_loop_unshielded.done())

    @property
    def is_stopped(self) -> bool:
        """True if errored, or the started background loop task is done."""
        return self.errored or (self.background_loop is not None and
                                self._background_loop_unshielded is not None
                                and self._background_loop_unshielded.done())

    @property
    def errored(self) -> bool:
        """True once an exception has been recorded via set_errored()."""
        return self._errored_with is not None

    @property
    def dead_error(self) -> BaseException:
        """Build (but do not raise) the error used when the loop is down."""
        return AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")

    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` as the reason this engine is no longer usable."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine errored and forward ``exc`` to the tracker."""
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)

    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Return the wrapped engine's input preprocessor."""
        return self.engine.input_preprocessor

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the wrapped engine's tokenizer for ``lora_request``."""
        return await self.engine.get_tokenizer_async(lora_request)

    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: if the engine has already errored.
            RuntimeError: if the loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        self._background_loop_unshielded = asyncio.get_event_loop(
        ).create_task(self.run_engine_loop(weakref.ref(self)))
        self._background_loop_unshielded.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        self.background_loop = asyncio.shield(self._background_loop_unshielded)

    def shutdown_background_loop(self) -> None:
        """
        Shut down the background loop.

        This method needs to be called during cleanup to remove
        references to `self` and properly GC the resources held
        by the async LLM engine (e.g., the executors as well as
        their resources).
        """
        if self._background_loop_unshielded is not None:
            self._background_loop_unshielded.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Returns True if there are in-progress requests."""

        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())

        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            try:
                await self.engine.add_request_async(**new_request)
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    new_request["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if aborted_requests:
            await self._engine_abort(aborted_requests)

        request_outputs = await self.engine.step_async(virtual_engine)

        # Put the outputs into the corresponding streams.
        # If used as a callback, then already invoked inside
        # LLMEngine's _process_model_outputs
        if not self.use_process_request_outputs_callback:
            all_finished = self.process_request_outputs(request_outputs)
        else:
            # For callback case, we only need to detect when all
            # requests are finished
            all_finished = all(request_output.finished
                               for request_output in request_outputs)

        return not all_finished

    def process_request_outputs(self, request_outputs) -> bool:
        """Route each output to its stream in the request tracker.

        Returns True if every output is finished (also for an empty batch).
        """
        # Put the outputs into the corresponding streams.
        all_finished = True
        for request_output in request_outputs:
            self._request_tracker.process_request_output(
                request_output, verbose=self.log_requests)
            all_finished = all_finished and request_output.finished

        return all_finished

    async def _engine_abort(self, request_ids: Iterable[str]):
        """Abort ``request_ids`` in the wrapped engine."""
        self.engine.abort_request(request_ids)

    @staticmethod
    async def run_engine_loop(engine_ref: ReferenceType):
        """We use a weakref to the engine so that the running loop
        doesn't prevent the engine being garbage collected."""
        engine: Optional[AsyncLLMEngine] = engine_ref()
        if not engine:
            return

        pipeline_parallel_size = \
                engine.engine.parallel_config.pipeline_parallel_size
        has_requests_in_progress = [False] * pipeline_parallel_size
        while True:
            if not any(has_requests_in_progress):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                await engine.engine.stop_remote_worker_execution_loop_async()
                request_tracker = engine._request_tracker
                # Allow engine to be garbage collected while
                # waiting for new requests
                del engine
                await asyncio.sleep(0)
                if engine_ref() is None:
                    return
                await request_tracker.wait_for_new_requests()
                engine = engine_ref()
                if not engine:
                    return
                logger.debug("Got new requests!")
                requests_in_progress = [
                    asyncio.create_task(engine.engine_step(ve))
                    for ve in range(pipeline_parallel_size)
                ]
                has_requests_in_progress = [True] * pipeline_parallel_size

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    # NOTE: idle virtual engines keep their completed task in
                    # requests_in_progress, so this wait returns immediately
                    # for them; that re-check is what restarts a virtual
                    # engine that received new work from another one's step.
                    done, _ = await asyncio.wait(
                        requests_in_progress,
                        return_when=asyncio.FIRST_COMPLETED)
                    for _ in range(pipeline_parallel_size):
                        await asyncio.sleep(0)
                for task in done:
                    result = task.result()
                    virtual_engine = requests_in_progress.index(task)
                    has_unfinished_requests = (
                        engine.engine.
                        has_unfinished_requests_for_virtual_engine(
                            virtual_engine))
                    if result or has_unfinished_requests:
                        requests_in_progress[virtual_engine] = (
                            asyncio.create_task(
                                engine.engine_step(virtual_engine)))
                        has_requests_in_progress[virtual_engine] = True
                    else:
                        has_requests_in_progress[virtual_engine] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                engine.set_errored(exc)
                raise
            await asyncio.sleep(0)

    # This method does not need to be async, but kept that way
    # for backwards compatibility.
    @overload
    @deprecated("'inputs' will be renamed to 'prompt'")
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, PoolingRequestOutput], None]]:
        ...

    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, PoolingRequestOutput], None]]:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request(
        self,
        request_id: str,
        prompt: Optional[PromptType] = None,
        params: Optional[Union[SamplingParams, PoolingParams]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        """Queue a request for the background loop and return its stream.

        Starts the background loop first if it is not running and
        ``start_engine_loop`` is enabled.

        Raises:
            AsyncEngineDeadError: if the loop is not running and may not be
                started, or has already errored.
            ValueError: if ``priority`` is non-zero while the scheduler
                policy is not ``"priority"``.
        """
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise self.dead_error

        if (priority != 0
                and self.engine.scheduler_config.policy != "priority"):
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            prompt=prompt,
            params=params,
            arrival_time=arrival_time or time.time(),
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

        return stream.generator()

    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use
                                            for generation, if any.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.
            - If the consumer is cancelled, the request is aborted and the
              cancellation is re-raised.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> # note that engine_args here is AsyncEngineArgs instance
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": 0,
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        try:
            async for output in await self.add_request(
                    request_id,
                    prompt,
                    sampling_params,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    prompt_adapter_request=prompt_adapter_request,
                    priority=priority,
            ):
                yield LLMEngine.validate_output(output, RequestOutput)
        except asyncio.CancelledError:
            await self.abort(request_id)
            raise

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Yields:
            The output `PoolingRequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.
            - If the consumer is cancelled, the request is aborted and the
              cancellation is re-raised.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> # note that engine_args here is AsyncEngineArgs instance
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": 0,
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        try:
            async for output in await self.add_request(
                    request_id,
                    prompt,
                    pooling_params,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=priority,
            ):
                yield LLMEngine.validate_output(output, PoolingRequestOutput)
        except asyncio.CancelledError:
            await self.abort(request_id)
            raise

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: if the background loop is not running.
        """
        if not self.is_running:
            raise self.dead_error

        return self._abort(request_id)

    def _abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
        self._request_tracker.abort_request(request_id,
                                            exception=asyncio.CancelledError,
                                            verbose=self.log_requests)

    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        return self.engine.get_model_config()

    async def get_parallel_config(self) -> ParallelConfig:
        """Get the parallel configuration of the vLLM engine."""
        return self.engine.get_parallel_config()

    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        return self.engine.get_decoding_config()

    async def get_scheduler_config(self) -> SchedulerConfig:
        """Get the scheduling configuration of the vLLM engine."""
        return self.engine.get_scheduler_config()

    async def get_lora_config(self) -> LoRAConfig:
        """Get the lora configuration of the vLLM engine."""
        return self.engine.get_lora_config()

    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the wrapped engine to log stats.

        The arguments are accepted for interface compatibility and ignored.
        """
        self.engine.do_log_stats()

    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        t = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        await self.engine.check_health_async()
        logger.debug("Health check took %fs", time.perf_counter() - t)

    async def is_tracing_enabled(self) -> bool:
        """Return whether tracing is enabled on the wrapped engine."""
        return self.engine.is_tracing_enabled()

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger on the wrapped engine."""
        self.engine.add_logger(logger_name=logger_name, logger=logger)

    def remove_logger(self, logger_name: str) -> None:
        """Remove a stat logger from the wrapped engine."""
        self.engine.remove_logger(logger_name=logger_name)

    async def start_profile(self) -> None:
        """Start profiling on the wrapped engine."""
        self.engine.start_profile()

    async def stop_profile(self) -> None:
        """Stop profiling on the wrapped engine."""
        self.engine.stop_profile()

    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        """Reset the wrapped engine's prefix cache (optionally per device)."""
        self.engine.reset_prefix_cache(device)

    async def sleep(self, level: int = 1) -> None:
        """Put the wrapped engine to sleep at the given level."""
        self.engine.sleep(level)

    async def wake_up(self) -> None:
        """Wake the wrapped engine up."""
        self.engine.wake_up()

    async def is_sleeping(self) -> bool:
        """Return whether the wrapped engine is sleeping."""
        return self.engine.is_sleeping()

    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Load a LoRA adapter into the wrapped engine."""
        self.engine.add_lora(lora_request)


# TODO(v1): Remove this class proxy when V1 goes default.
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
    # Only when VLLM_USE_V1 is explicitly set and true, rebind the public
    # name so that importers of ``AsyncLLMEngine`` from this module get the
    # V1 ``AsyncLLM`` class instead of the V0 class defined above.
    from vllm.v1.engine.async_llm import AsyncLLM

    AsyncLLMEngine = AsyncLLM  # type: ignore