async_llm_engine.py 47.3 KB
Newer Older
1
import asyncio
2
import copy
3
import time
4
import weakref
Antoni Baum's avatar
Antoni Baum committed
5
from functools import partial
6
7
from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
                    List, Mapping, Optional, Set, Tuple, Type, Union, overload)
8
from weakref import ReferenceType
9

10
11
from typing_extensions import deprecated

12
import vllm.envs as envs
13
14
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, VllmConfig)
15
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
16
from vllm.engine.arg_utils import AsyncEngineArgs
17
from vllm.engine.async_timeout import asyncio_timeout
18
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
19
from vllm.engine.metrics_types import StatLoggerBase
20
from vllm.engine.protocol import EngineClient
21
from vllm.executor.executor_base import ExecutorBase
22
from vllm.inputs import PromptType
23
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
24
from vllm.logger import init_logger
25
from vllm.lora.request import LoRARequest
26
27
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
28
from vllm.model_executor.layers.sampler import SamplerOutput
29
from vllm.outputs import PoolingRequestOutput, RequestOutput
30
from vllm.pooling_params import PoolingParams
31
from vllm.prompt_adapter.request import PromptAdapterRequest
32
from vllm.sampling_params import SamplingParams
33
from vllm.sequence import ExecuteModelRequest
34
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
35
from vllm.usage.usage_lib import UsageContext
36
from vllm.utils import deprecate_kwargs, weak_bind
37
38

logger = init_logger(__name__)

# Engine iteration timeout in seconds (per the ``_S`` suffix), read once at
# import time from ``vllm.envs``.
# NOTE(review): presumably applied per background-loop iteration together
# with ``asyncio_timeout``; that usage is not visible in this section.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S


class AsyncEngineDeadError(RuntimeError):
    """Signals that the engine's background loop task has died.

    ``_log_task_completion`` raises it, chained from the original error,
    when the background task ends with an exception other than
    cancellation.
    """


def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Handle completion of the ``engine.run_engine_loop()`` task.

    This function is only intended for that task, which runs a
    ``while True`` loop that can only exit if there is an exception. Every
    completion is therefore treated as abnormal:

    * cancellation is logged as a graceful shutdown and swallowed;
    * any other outcome, including a normal return (which is converted to
      an ``AssertionError`` so it takes the same path), is logged, passed to
      ``error_callback`` and re-raised as :class:`AsyncEngineDeadError`.

    ``task.result()`` is called directly, so ``task`` is expected to be done.

    Args:
        task: The finished background task.
        error_callback: Called with the exception that ended the task.

    Raises:
        AsyncEngineDeadError: If the task did not end by cancellation.
    """
    try:
        return_value = task.result()
        # Raised inside the ``try`` on purpose so the generic handler below
        # logs it and notifies ``error_callback`` like any other failure.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {return_value}")
    except asyncio.exceptions.CancelledError:
        # We assume that if the task is cancelled, we are gracefully shutting
        # down. This should only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from e


# Sentinel queued by ``AsyncStream.finish()`` when a stream ends without an
# error; ``AsyncStream.generator()`` returns when it dequeues this object.
# It is an ``Exception`` instance so that ``AsyncStream._is_raisable``
# recognises it alongside real errors.
STOP_ITERATION = Exception()  # Sentinel


class AsyncStream:
    """A stream of RequestOutputs or PoolingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator.

    Producers call :meth:`put` for each output and :meth:`finish` at the
    end; the consumer iterates :meth:`generator`. Once the stream is
    finished, further ``put``/``finish`` calls are ignored.
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        """
        Args:
            request_id: ID of the request this stream belongs to.
            cancel: Called with ``request_id`` if the consumer closes the
                generator while it is suspended.
        """
        self.request_id = request_id
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, PoolingRequestOutput,
                              Exception]) -> None:
        """Enqueue an output, or an exception for the consumer to raise.

        No-op once the stream has been finished.
        """
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Mark the stream as finished; only the first call has an effect.

        Args:
            exception: If an exception instance or class is given, the
                consumer raises it after draining earlier items; otherwise
                iteration ends normally.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has been called."""
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        """Yield queued outputs until the stream ends.

        Raises:
            BaseException: Any exception instance or class that was queued
                via :meth:`put` or :meth:`finish`.
            asyncio.CancelledError: If the generator is closed while
                suspended; the ``cancel`` callback is invoked first.
        """
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    # Identity check: the sentinel is one unique object.
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """Return True for exception instances and exception classes."""
        return isinstance(value, BaseException) or (
            isinstance(value, type) and issubclass(value, BaseException))


class RequestTracker:
    """Synchronous abstraction for tracking requests.

    New requests and aborts are queued here and handed to the engine in
    batches by :meth:`get_new_and_aborted_requests`. Streams of requests
    that have been handed over are kept until they finish or are aborted;
    ``len()`` and ``in`` only see those handed-over streams.
    """

    def __init__(self) -> None:
        # Streams of requests already handed to the engine, keyed by id.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Request ids to abort on the next loop iteration.
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, engine add_request kwargs) pairs not yet handed over.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set by add_request(); awaited and cleared by wait_for_new_requests().
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None)."""
        if request_id is not None:
            self.abort_request(request_id, exception=exc)
            return
        # Snapshot the ids: abort_request() pops streams out of
        # self._request_streams, so it can't be iterated on directly.
        for rid in tuple(self._request_streams):
            self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     PoolingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        Forwards the output to its request's stream; when the output is
        marked finished, the stream is also finished and no longer tracked.
        Outputs for requests without a tracked stream are dropped.
        """
        request_id = request_output.request_id
        finished = request_output.finished

        if finished:
            stream = self._request_streams.pop(request_id, None)
        else:
            stream = self._request_streams.get(request_id)
        # Guard against a KeyError which can occur if the request was aborted
        # while the output was generated
        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if verbose and finished:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine to the request's stream."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Args:
            request_id: Unique request id.
            verbose: Log the addition (and a later abort via the stream).
            **engine_add_request_kwargs: Forwarded, together with
                ``request_id``, as the new-request dict.

        Returns:
            The stream that will receive this request's outputs.

        Raises:
            KeyError: If a stream for ``request_id`` is already tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        # Closing the stream's generator early aborts the request.
        abort_request = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, abort_request)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Abort a request during next background loop iteration.

        The id is always queued for the engine. If the request's stream is
        already tracked it is finished immediately, with ``exception`` if
        one is given.
        """
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        Drains both queues. A request that was aborted before being handed
        over is never sent: its stream is finished with
        ``asyncio.CancelledError`` and its id is dropped from the aborted set.

        Returns:
            ``(new_request_kwargs_dicts, aborted_request_ids)``.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._aborted_requests.empty():
            request_id = self._aborted_requests.get_nowait()
            finished_requests.add(request_id)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            request_id = stream.request_id
            if request_id in finished_requests:
                # The request has already been aborted.
                stream.finish(asyncio.CancelledError)
                finished_requests.discard(request_id)
            else:
                self._request_streams[request_id] = stream
                new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Wait until a new request has been queued, then clear the event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Return True if any added request has not been handed over yet."""
        return not self._new_requests.empty()


class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.

        Args:
            virtual_engine: Index of the scheduler / scheduler context to
                step.

        Returns:
            The request outputs collected in this virtual engine's scheduler
            context (``ctx.request_outputs``).
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):

            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)
        else:
            finished_requests_ids = []

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # Clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1 (checked on the first group). When the
            # num_steps > 1, multi_step_model_runner does the first-step
            # output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            if outputs and allow_async_output_proc:
                assert len(outputs) == 1, (
                    "Async postprocessor expects only a single output set")
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)

        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def get_tokenizer_async(self,
                                  lora_request: Optional[LoRARequest] = None
                                  ) -> AnyTokenizer:
        """Return the tokenizer to use for ``lora_request`` (or the base
        tokenizer when it is None), fetched from the tokenizer group."""
        return await (
            self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))

    @overload
    @deprecated("'inputs' will be renamed to 'prompt'")
    async def add_request_async(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @overload
    async def add_request_async(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request_async(
            self,
            request_id: str,
            prompt: Optional[PromptType] = None,
            params: Optional[Union[SamplingParams, PoolingParams]] = None,
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
            *,
            inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        """Async version of :meth:`add_request`.

        Validates the request, preprocesses ``prompt`` asynchronously and,
        for sampling requests with guided decoding, builds the guided
        decoding logits processor with the async builder before passing the
        processed request to ``_add_processed_request``.

        ``arrival_time`` defaults to ``time.time()``.

        Raises:
            ValueError: If ``lora_request`` is given but LoRA is not enabled,
                or a non-zero ``priority`` is given while the scheduler
                policy is not "priority".
        """
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")
        if arrival_time is None:
            arrival_time = time.time()

        if self.tokenizer is not None:
            tokenizer = await self.get_tokenizer_async(lora_request)
            self._validate_token_prompt(prompt, tokenizer=tokenizer)

        preprocessed_inputs = await self.input_preprocessor.preprocess_async(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        if isinstance(params, SamplingParams) and \
            params.guided_decoding is not None:
            # Guided decoding has an async implementation for building logits
            # processors in a separate threadpool.
            # We want to invoke that here instead of using the blocking
            # implementation in the LLMEngine
            params = await build_guided_decoding_logits_processor_async(
                sampling_params=params,
                tokenizer=await self.get_tokenizer_async(lora_request),
                default_guided_backend=(
                    self.decoding_config.guided_decoding_backend),
                model_config=self.model_config)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    async def check_health_async(self) -> None:
        """Run the tokenizer's health check (if a tokenizer is configured)
        and then the model executor's health check."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()


async def build_guided_decoding_logits_processor_async(
        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
        default_guided_backend: str,
        model_config: ModelConfig) -> SamplingParams:
    """Turn ``sampling_params.guided_decoding`` into a logits processor.

    Builds a logits processor from the ``guided_decoding`` field, appends it
    to ``logits_processors`` and clears ``guided_decoding``.

    The caller's object is never mutated: a shallow copy of the sampling
    params is returned, and the nested ``guided_decoding`` params and the
    ``logits_processors`` list are copied before being changed. This keeps a
    ``SamplingParams`` instance that is reused across requests intact.

    Args:
        sampling_params: The request's sampling parameters.
        tokenizer: Tokenizer forwarded to the processor builder.
        default_guided_backend: Backend used when
            ``guided_decoding.backend`` is not set.
        model_config: Model config forwarded to the processor builder.

    Returns:
        ``sampling_params`` unchanged if it has no guided decoding params,
        otherwise the updated copy.
    """
    if sampling_params.guided_decoding is None:
        return sampling_params

    # Defensively copy sampling params since guided decoding logits
    # processors can have different state for each request. copy.copy() is
    # shallow, so the mutable members changed below are copied separately.
    sampling_params = copy.copy(sampling_params)
    guided_decoding = copy.copy(sampling_params.guided_decoding)

    logger.debug("Building guided decoding logits processor. "
                 "Params: %s", guided_decoding)

    # Resolve the backend on our private copy, not on the caller's object.
    guided_decoding.backend = guided_decoding.backend or default_guided_backend

    processor = await get_guided_decoding_logits_processor(
        guided_params=guided_decoding,
        tokenizer=tokenizer,
        model_config=model_config)

    if processor:
        # Build a new list instead of appending in place: after the shallow
        # copy, the existing list is still shared with the caller.
        sampling_params.logits_processors = [
            *(sampling_params.logits_processors or []), processor
        ]

    # Unset guided decoding params after constructing the lp from them
    sampling_params.guided_decoding = None

    return sampling_params


class AsyncLLMEngine(EngineClient):
    """An asynchronous wrapper for :class:`LLMEngine`.

    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.

    Args:
        log_requests: Whether to log the requests.
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
    """

    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

    def __init__(self,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        self.log_requests = log_requests
        self.engine = self._engine_class(*args, **kwargs)

        # This ensures quick processing of request outputs
        # so the append to asyncio queues is not delayed,
        # especially for multi-step.
        self.use_process_request_outputs_callback = (
            self.engine.model_config.use_async_output_proc)

        if self.use_process_request_outputs_callback:
            # weak_bind avoids the inner engine holding a strong reference
            # back to this wrapper through the callback.
            self.engine.process_request_outputs_callback = \
                weak_bind(self.process_request_outputs)

        self.background_loop: Optional[asyncio.Future] = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Lazy initialized fields (assigned in start_background_loop)
        self._request_tracker: RequestTracker

    def __del__(self):
        """Wake the engine loop so it notices this engine is gone and exits.

        ``run_engine_loop`` drops its strong reference while it waits for
        new requests. Once woken, it re-resolves its weakref, finds ``None``
        and returns.
        """
        # The tracker is stored as `_request_tracker` (see
        # start_background_loop). Looking up "request_tracker" always missed,
        # so the waiting loop was never woken. getattr with a default is
        # still needed because the attribute only exists once the background
        # loop has been started.
        if rt := getattr(self, "_request_tracker", None):
            # Wake up engine loop so that it will exit cleanly
            rt.new_requests_event.set()

    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Return the executor class for ``engine_config``.

        Delegates to ``LLMEngine._get_executor_cls``.
        """
        return LLMEngine._get_executor_cls(engine_config)

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        engine_config: Optional[VllmConfig] = None,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        Args:
            engine_args: The engine arguments.
            engine_config: A pre-built config. When ``None`` it is created
                from ``engine_args`` using ``usage_context``.
            start_engine_loop: Forwarded to the constructor.
            usage_context: Usage context for config creation and the engine.
            stat_loggers: Optional stat loggers forwarded to the engine.
        """
        # Create the engine configs.
        if engine_config is None:
            engine_config = engine_args.create_engine_config(usage_context)

        executor_class = cls._get_executor_cls(engine_config)

        # Create the async LLM engine.
        engine = cls(
            vllm_config=engine_config,
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )
        return engine

    @property
    def is_running(self) -> bool:
        """True while the background loop task exists and is not done."""
        return (self.background_loop is not None
                and self._background_loop_unshielded is not None
                and not self._background_loop_unshielded.done())

    @property
    def is_stopped(self) -> bool:
        """True if the engine errored or its started background task ended."""
        return self.errored or (self.background_loop is not None and
                                self._background_loop_unshielded is not None
                                and self._background_loop_unshielded.done())

    @property
    def errored(self) -> bool:
        """True once an exception has been recorded via ``set_errored``."""
        return self._errored_with is not None

    @property
    def dead_error(self) -> BaseException:
        """A new ``AsyncEngineDeadError`` describing a stopped loop."""
        return AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")

    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` as the reason this engine is no longer usable."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Error hook for the background task's completion callback.

        Marks the engine as errored and hands ``exc`` to the request
        tracker's ``propagate_exception``.
        """
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)

    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Return the underlying engine's input preprocessor."""
        return self.engine.input_preprocessor

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer, optionally for a specific LoRA request."""
        return await self.engine.get_tokenizer_async(lora_request)

    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the background loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        # The loop only receives a weakref so it does not keep `self` alive.
        self._background_loop_unshielded = asyncio.get_event_loop(
        ).create_task(self.run_engine_loop(weakref.ref(self)))
        self._background_loop_unshielded.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        self.background_loop = asyncio.shield(self._background_loop_unshielded)

    def shutdown_background_loop(self) -> None:
        """
        Shut down the background loop.

        This method needs to be called during cleanup to remove
        references to `self` and properly GC the resources held
        by the async LLM engine (e.g., the executors as well as
        their resources).
        """
        if self._background_loop_unshielded is not None:
            self._background_loop_unshielded.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Returns True if there are in-progress requests."""

        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())

        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            try:
                await self.engine.add_request_async(**new_request)
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    new_request["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if aborted_requests:
            await self._engine_abort(aborted_requests)

        request_outputs = await self.engine.step_async(virtual_engine)

        # Put the outputs into the corresponding streams.
        # If used as a callback, then already invoked inside
        # LLMEngine's _process_model_outputs
        if not self.use_process_request_outputs_callback:
            all_finished = self.process_request_outputs(request_outputs)
        else:
            # For callback case, we only need to detect when all
            # requests are finished
            all_finished = all(request_output.finished
                               for request_output in request_outputs)

        return not all_finished

    def process_request_outputs(self, request_outputs) -> bool:
        """Route each output to its stream via the request tracker.

        Returns True if every output in ``request_outputs`` is finished
        (including when there are none).
        """
        # Put the outputs into the corresponding streams.
        all_finished = True
        for request_output in request_outputs:
            self._request_tracker.process_request_output(
                request_output, verbose=self.log_requests)
            all_finished = all_finished and request_output.finished

        return all_finished

    async def _engine_abort(self, request_ids: Iterable[str]):
        """Forward the given request ids to the underlying engine's abort."""
        self.engine.abort_request(request_ids)

    @staticmethod
    async def run_engine_loop(engine_ref: ReferenceType):
        """We use a weakref to the engine so that the running loop
        doesn't prevent the engine being garbage collected.

        One ``engine_step`` task is kept per pipeline-parallel virtual
        engine. When none of them has work, the loop parks on the request
        tracker until new requests arrive, and returns as soon as the engine
        has been garbage collected.
        """
        engine: Optional[AsyncLLMEngine] = engine_ref()
        if not engine:
            return

        pipeline_parallel_size = \
                engine.engine.parallel_config.pipeline_parallel_size
        has_requests_in_progress = [False] * pipeline_parallel_size
        while True:
            if not any(has_requests_in_progress):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                await engine.engine.stop_remote_worker_execution_loop_async()
                request_tracker = engine._request_tracker
                # Allow engine to be garbage collected while
                # waiting for new requests
                del engine
                await asyncio.sleep(0)
                if engine_ref() is None:
                    return
                await request_tracker.wait_for_new_requests()
                engine = engine_ref()
                if not engine:
                    return
                logger.debug("Got new requests!")
                requests_in_progress = [
                    asyncio.create_task(engine.engine_step(ve))
                    for ve in range(pipeline_parallel_size)
                ]
                has_requests_in_progress = [True] * pipeline_parallel_size

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        requests_in_progress,
                        return_when=asyncio.FIRST_COMPLETED)
                    for _ in range(pipeline_parallel_size):
                        await asyncio.sleep(0)
                for task in done:
                    result = task.result()
                    virtual_engine = requests_in_progress.index(task)
                    has_unfinished_requests = (
                        engine.engine.
                        has_unfinished_requests_for_virtual_engine(
                            virtual_engine))
                    if result or has_unfinished_requests:
                        requests_in_progress[virtual_engine] = (
                            asyncio.create_task(
                                engine.engine_step(virtual_engine)))
                        has_requests_in_progress[virtual_engine] = True
                    else:
                        has_requests_in_progress[virtual_engine] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                engine.set_errored(exc)
                raise
            await asyncio.sleep(0)

    # This method does not need to be async, but kept that way
    # for backwards compatibility.
    @overload
    @deprecated("'inputs' will be renamed to 'prompt'")
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, PoolingRequestOutput], None]]:
        ...

    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, PoolingRequestOutput], None]]:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request(
        self,
        request_id: str,
        prompt: Optional[PromptType] = None,
        params: Optional[Union[SamplingParams, PoolingParams]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        """Register a request with the tracker and return its output stream.

        Starts the background loop first if it is not running and
        ``start_engine_loop`` is enabled. A missing ``arrival_time`` (or a
        falsy one) is replaced with ``time.time()``.

        Raises:
            AsyncEngineDeadError: If the loop is not running and automatic
                start is disabled.
            ValueError: If a non-zero ``priority`` is given while the
                scheduler policy is not ``"priority"``.
        """
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise self.dead_error

        if (priority != 0
                and self.engine.scheduler_config.policy != "priority"):
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            prompt=prompt,
            params=params,
            arrival_time=arrival_time or time.time(),
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

        return stream.generator()

    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use
                                            for generation, if any.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.
            - If the consuming task is cancelled, the request is aborted
              before the cancellation is re-raised.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> # note that engine_args here is AsyncEngineArgs instance
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        try:
            async for output in await self.add_request(
                    request_id,
                    prompt,
                    sampling_params,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    prompt_adapter_request=prompt_adapter_request,
                    priority=priority,
            ):
                yield LLMEngine.validate_output(output, RequestOutput)
        except asyncio.CancelledError:
            await self.abort(request_id)
            raise

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Yields:
            The output `PoolingRequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.
            - If the consuming task is cancelled, the request is aborted
              before the cancellation is re-raised.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> # note that engine_args here is AsyncEngineArgs instance
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        try:
            async for output in await self.add_request(
                    request_id,
                    prompt,
                    pooling_params,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=priority,
            ):
                yield LLMEngine.validate_output(output, PoolingRequestOutput)
        except asyncio.CancelledError:
            await self.abort(request_id)
            raise

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if not self.is_running:
            raise self.dead_error

        self._abort(request_id)

    def _abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
        self._request_tracker.abort_request(request_id,
                                            exception=asyncio.CancelledError,
                                            verbose=self.log_requests)

    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        return self.engine.get_model_config()

    async def get_parallel_config(self) -> ParallelConfig:
        """Get the parallel configuration of the vLLM engine."""
        return self.engine.get_parallel_config()

    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        return self.engine.get_decoding_config()

    async def get_scheduler_config(self) -> SchedulerConfig:
        """Get the scheduling configuration of the vLLM engine."""
        return self.engine.get_scheduler_config()

    async def get_lora_config(self) -> LoRAConfig:
        """Get the lora configuration of the vLLM engine."""
        return self.engine.get_lora_config()

    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the engine to log stats; both arguments are ignored here."""
        self.engine.do_log_stats()

    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        t = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        await self.engine.check_health_async()
        logger.debug("Health check took %fs", time.perf_counter() - t)

    async def is_tracing_enabled(self) -> bool:
        """Return whether the underlying engine has tracing enabled."""
        return self.engine.is_tracing_enabled()

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger on the underlying engine."""
        self.engine.add_logger(logger_name=logger_name, logger=logger)

    def remove_logger(self, logger_name: str) -> None:
        """Remove a stat logger from the underlying engine."""
        self.engine.remove_logger(logger_name=logger_name)

    async def start_profile(self) -> None:
        """Start profiling on the underlying engine."""
        self.engine.start_profile()

    async def stop_profile(self) -> None:
        """Stop profiling on the underlying engine."""
        self.engine.stop_profile()

    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Load a LoRA adapter into the underlying engine."""
        self.engine.add_lora(lora_request)


# TODO(v1): Remove this class proxy when V1 goes default.
# When envs.VLLM_USE_V1 is truthy at import time, the module-level name
# `AsyncLLMEngine` is rebound to the V1 `AsyncLLM` class. Code that imports
# `AsyncLLMEngine` from this module afterwards then gets the V1 implementation
# instead of the V0 class defined in this file.
if envs.VLLM_USE_V1:
    # Imported inside the branch so the V1 module is only loaded when V1 is on.
    from vllm.v1.engine.async_llm import AsyncLLM

    AsyncLLMEngine = AsyncLLM  # type: ignore