async_llm_engine.py 49.1 KB
Newer Older
1
2
import asyncio
import time
3
import weakref
Antoni Baum's avatar
Antoni Baum committed
4
from functools import partial
5
6
from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
                    List, Mapping, Optional, Set, Tuple, Type, Union, overload)
7
from weakref import ReferenceType
8

9
import vllm.envs as envs
10
11
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
12
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
13
from vllm.engine.arg_utils import AsyncEngineArgs
14
from vllm.engine.async_timeout import asyncio_timeout
15
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
16
from vllm.engine.metrics_types import StatLoggerBase
17
from vllm.engine.protocol import EngineClient
18
from vllm.executor.executor_base import ExecutorAsyncBase
19
from vllm.executor.gpu_executor import GPUExecutorAsync
20
from vllm.executor.ray_utils import initialize_ray_cluster
21
from vllm.inputs import PromptType
Woosuk Kwon's avatar
Woosuk Kwon committed
22
from vllm.logger import init_logger
23
from vllm.lora.request import LoRARequest
24
25
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
26
from vllm.model_executor.layers.sampler import SamplerOutput
27
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
28
from vllm.pooling_params import PoolingParams
29
from vllm.prompt_adapter.request import PromptAdapterRequest
30
from vllm.sampling_params import SamplingParams
31
from vllm.sequence import ExecuteModelRequest
32
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
33
from vllm.usage.usage_lib import UsageContext
34
from vllm.utils import deprecate_kwargs, weak_bind
35
36

logger = init_logger(__name__)
37
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
38

Antoni Baum's avatar
Antoni Baum committed
39

40
41
42
43
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine's background task has stopped unexpectedly."""


44
45
46
47
48
49
50
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback meant only for the `engine.run_engine_loop()` task.

    That task runs a `while True` loop, so it can only end by being cancelled
    or by raising. Cancellation is logged as a graceful shutdown. Any other
    outcome — including an unexpected normal return, which is turned into an
    AssertionError — is logged, passed to `error_callback`, and re-raised as
    AsyncEngineDeadError.

    Args:
        task: The finished engine background task.
        error_callback: Receives the exception that ended the task.

    Raises:
        AsyncEngineDeadError: If the task ended for any reason other than
            cancellation.
    """
    if task.cancelled():
        # We assume cancellation means a graceful shutdown; this should only
        # happen on program exit.
        logger.info("Engine is gracefully shutting down.")
        return

    try:
        outcome = task.result()
        # Returning normally is itself a failure for an endless loop; raise so
        # it is handled by the same path as a real exception.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {outcome}")
    except Exception as failure:
        logger.error("Engine background task failed", exc_info=failure)
        error_callback(failure)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from failure
70
71


72
73
74
# End-of-stream marker queued by AsyncStream.finish() when no exception is
# given. It is an Exception instance so AsyncStream._is_raisable() accepts it;
# AsyncStream.generator() then recognises it and returns instead of raising.
STOP_ITERATION = Exception()  # Sentinel


Antoni Baum's avatar
Antoni Baum committed
75
class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator.

    Producers call :meth:`put` for each output and :meth:`finish` to end the
    stream, either normally or with an exception. Items put after
    :meth:`finish` are dropped. The consumer iterates :meth:`generator`.
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        # Invoked with `request_id` when the consumer closes the generator
        # before the stream has ended.
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Queue `item` for the consumer; ignored once the stream finished.

        An exception instance/class put here is raised to the consumer when
        it is reached.
        """
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Mark the stream finished; later calls are no-ops.

        Args:
            exception: If an exception instance or class, the consumer has it
                raised after draining earlier items. Otherwise the stream ends
                normally.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has been called."""
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Yield queued outputs in order until the stream ends.

        Raises:
            The exception given to :meth:`finish` (or queued via :meth:`put`).
            asyncio.CancelledError: If the consumer closes the generator
                early; the cancel callback is invoked first.
        """
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    # STOP_ITERATION is a unique sentinel: compare by identity.
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """Return True if `value` is an exception instance or class."""
        return isinstance(value, BaseException) or (
            isinstance(value, type) and issubclass(value, BaseException))

Antoni Baum's avatar
Antoni Baum committed
124

125
126
127
128
129
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    Callers queue new and aborted requests here. The background loop collects
    them with :meth:`get_new_and_aborted_requests` to send to the engine, and
    outputs/errors coming back are routed to each request's
    :class:`AsyncStream`.
    """

    def __init__(self) -> None:
        # Streams of requests that have been handed out to the engine.
        self._request_streams: Dict[str, AsyncStream] = {}
        # IDs queued by add_request() but not yet collected by
        # get_new_and_aborted_requests(). Tracked so a duplicate ID is
        # rejected even before the background loop drains _new_requests;
        # otherwise the second stream would overwrite the first in
        # _request_streams and the first would never finish.
        self._pending_request_ids: Set[str] = set()
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set by add_request(); awaited by wait_for_new_requests().
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None)."""
        if request_id is not None:
            self.abort_request(request_id, exception=exc)
        else:
            # NB: tuple() used here because self.abort_request pops the stream
            # out of self._request_streams, so we can't iterate on it directly
            for rid in tuple(self._request_streams.keys()):
                self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Route an output from the engine to its request's stream.

        A finished output also finishes the stream and stops tracking it.
        """
        request_id = request_output.request_id
        finished = request_output.finished

        if finished:
            stream = self._request_streams.pop(request_id, None)
        else:
            stream = self._request_streams.get(request_id)
        # The stream may be gone if the request was aborted while this output
        # was being generated.
        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if verbose and finished:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Raises:
            KeyError: If a request with this ID is already active or still
                waiting to be handed to the engine.
        """
        if (request_id in self._request_streams
                or request_id in self._pending_request_ids):
            raise KeyError(f"Request {request_id} already exists.")

        abort_request = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, abort_request)
        self._pending_request_ids.add(request_id)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Abort a request during next background loop iteration."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)
        # A still-pending request is cancelled when it is dequeued (its ID is
        # in the aborted set), so release the ID now; this keeps
        # abort-then-resubmit with the same ID working.
        self._pending_request_ids.discard(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine."""
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._aborted_requests.empty():
            request_id = self._aborted_requests.get_nowait()
            finished_requests.add(request_id)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            request_id = stream.request_id
            # No longer pending: it is either cancelled or becomes active.
            self._pending_request_ids.discard(request_id)
            if request_id in finished_requests:
                # The request has already been aborted.
                stream.finish(asyncio.CancelledError)
                finished_requests.discard(request_id)
            else:
                self._request_streams[request_id] = stream
                new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Wait until a new request is queued, then reset the event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Return True if requests are queued but not yet collected."""
        return not self._new_requests.empty()
257

Antoni Baum's avatar
Antoni Baum committed
258
259
260
261

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.

        Args:
            virtual_engine: Index selecting the scheduler, scheduler context
                and cached scheduler outputs used for this step.

        Returns:
            The request outputs collected in this virtual engine's scheduler
            context during the step.
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):

            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            # Nothing scheduled: still flush any queued async outputs.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # Clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            if outputs and allow_async_output_proc:
                assert len(
                    outputs
                ) == 1, "Async postprocessor expects only a single output set"
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)

        else:
            # Multi-step case: the batch still has remaining steps, so return
            # without appending outputs or draining the output queue.
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    @overload  # DEPRECATED
    async def add_request_async(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @overload
    async def add_request_async(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request_async(
            self,
            request_id: str,
            prompt: Optional[PromptType] = None,
            params: Optional[Union[SamplingParams, PoolingParams]] = None,
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
            *,
            inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        """Async version of :meth:`add_request`.

        Validates the request, preprocesses the prompt asynchronously, builds
        any guided-decoding logits processor asynchronously, and then hands
        the processed request to `_add_processed_request`.

        Raises:
            TypeError: If no prompt (via `prompt` or the deprecated `inputs`)
                or no `params` is given.
            ValueError: If a LoRA request is given while LoRA is disabled, or
                a non-zero priority is given without priority scheduling.
        """
        if inputs is not None:
            prompt = inputs
        # Explicit check instead of `assert`, which is stripped under -O.
        if prompt is None or params is None:
            raise TypeError(
                "add_request_async() requires both 'prompt' and 'params'.")

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")
        if arrival_time is None:
            arrival_time = time.time()

        preprocessed_inputs = await self.input_preprocessor.preprocess_async(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        if isinstance(params, SamplingParams) and \
            params.guided_decoding is not None:
            # Guided decoding has an async implementation for building logits
            # processors in a separate threadpool.
            # We want to invoke that here instead of using the blocking
            # implementation in the LLMEngine
            params = await build_guided_decoding_logits_processor_async(
                sampling_params=params,
                tokenizer=self.get_tokenizer(lora_request),
                default_guided_backend=self.decoding_config.
                guided_decoding_backend)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    async def check_health_async(self) -> None:
        """Check tokenizer (when present) and model executor health."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
510

511

512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
async def build_guided_decoding_logits_processor_async(
        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
        default_guided_backend: str) -> SamplingParams:
    """Build the guided-decoding logits processor for `sampling_params`.

    If `sampling_params.guided_decoding` is None, `sampling_params` is
    returned unchanged. Otherwise:
      * the guided-decoding backend falls back to `default_guided_backend`
        when none is set (this mutates the `guided_decoding` object);
      * a processor is built asynchronously via
        `get_guided_decoding_logits_processor` and, if one is returned,
        appended to `sampling_params.logits_processors` (created if None);
      * `sampling_params.guided_decoding` is cleared, since the processor has
        now been built from it.

    `sampling_params` is modified in place, and the same object is returned.
    """
    if (guided_decoding := sampling_params.guided_decoding) is None:
        return sampling_params

    logger.debug("Building guided decoding logits processor. "
                 "Params: %s", guided_decoding)

    guided_decoding.backend = guided_decoding.backend or default_guided_backend

    processor = await get_guided_decoding_logits_processor(
        guided_params=guided_decoding, tokenizer=tokenizer)

    if processor:
        if sampling_params.logits_processors is None:
            sampling_params.logits_processors = []
        sampling_params.logits_processors.append(processor)

    # Unset guided decoding params after constructing the lp from them
    sampling_params.guided_decoding = None

    return sampling_params


542
class AsyncLLMEngine(EngineClient):
543
    """An asynchronous wrapper for :class:`LLMEngine`.
544

545
546
547
548
549
    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.
550
551

    Args:
552
        log_requests: Whether to log the requests.
553
554
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
555
556
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
557
    """
558

Antoni Baum's avatar
Antoni Baum committed
559
560
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

561
562
563
    def __init__(self,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Create the wrapped engine and the background-loop bookkeeping.

        Args:
            *args: Forwarded unchanged to ``self._engine_class``.
            log_requests: Used as ``verbose`` for request-tracker calls.
            start_engine_loop: Start the background loop automatically from
                :meth:`add_request` when it is not already running.
            **kwargs: Forwarded unchanged to ``self._engine_class``.
        """
        self.log_requests = log_requests
        self.start_engine_loop = start_engine_loop
        self.engine = self._engine_class(*args, **kwargs)

        # Handing outputs to the asyncio queues straight from the engine's
        # output processing keeps them from being delayed, which matters
        # most for multi-step scheduling.
        model_cfg = self.engine.model_config
        self.use_process_request_outputs_callback = (
            model_cfg.use_async_output_proc)
        if self.use_process_request_outputs_callback:
            self.engine.process_request_outputs_callback = weak_bind(
                self.process_request_outputs)

        # ``background_loop`` is the shielded future; the raw task is kept
        # too so that it is not garbage collected while running.
        self.background_loop: Optional[asyncio.Future] = None
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self._errored_with: Optional[BaseException] = None

        # Created lazily in start_background_loop.
        self._request_tracker: RequestTracker

590
591
592
593
594
    def __del__(self):
        """Wake the background loop so it notices the engine is gone.

        While idle, ``run_engine_loop`` waits on the tracker's new-request
        event holding only a weak reference to the engine. Setting the event
        lets it wake up, find the weakref dead and return.
        """
        # The tracker is stored as ``_request_tracker`` and only exists once
        # start_background_loop has run, hence getattr with a default.
        # (Looking up "request_tracker" here never matched, so the loop was
        # never woken.)
        if (rt := getattr(self, "_request_tracker", None)) is not None:
            # Wake up engine loop so that it will exit cleanly
            rt.new_requests_event.set()

595
    @classmethod
    def _get_executor_cls(
            cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
        """Choose the async executor class for ``engine_config``.

        An executor class passed directly as ``distributed_executor_backend``
        wins (it must subclass :class:`ExecutorAsyncBase`). Otherwise the
        device type is consulted first, then the backend name. Any device
        type not listed below falls through to the GPU executors.

        Raises:
            TypeError: if a custom backend class is not an ExecutorAsyncBase.
            RuntimeError: if the XPU backend name is not None/"ray"/"mp".
            AssertionError: if a distributed backend is requested on TPU
                (other than "ray") or OpenVINO.
        """
        backend = engine_config.parallel_config.distributed_executor_backend

        # A user-supplied executor class takes precedence over everything.
        if isinstance(backend, type):
            if issubclass(backend, ExecutorAsyncBase):
                return backend
            raise TypeError(
                "distributed_executor_backend must be a subclass of "
                f"ExecutorAsyncBase. Got {backend}.")

        device_type = engine_config.device_config.device_type

        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            return NeuronExecutorAsync

        if device_type == "tpu":
            if backend == "ray":
                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                return RayTPUExecutorAsync
            assert backend is None
            from vllm.executor.tpu_executor import TPUExecutorAsync
            return TPUExecutorAsync

        if device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutorAsync
            return CPUExecutorAsync

        if device_type == "openvino":
            assert backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            return OpenVINOExecutorAsync

        if device_type == "xpu":
            if backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                return XPUExecutorAsync
            if backend == "ray":
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                return RayXPUExecutorAsync
            if backend == "mp":
                from vllm.executor.multiproc_xpu_executor import (
                    MultiprocessingXPUExecutorAsync)
                return MultiprocessingXPUExecutorAsync
            raise RuntimeError(
                "Not supported distributed execution model on XPU device.")

        # Everything else is selected purely by the backend name.
        if backend == "ray":
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            return RayGPUExecutorAsync
        if backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            return MultiprocessingGPUExecutorAsync
        return GPUExecutorAsync

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        engine_config: Optional[EngineConfig] = None,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Build an async LLM engine from :class:`AsyncEngineArgs`.

        Args:
            engine_args: Parsed engine arguments.
            engine_config: Pre-built config; created from ``engine_args``
                when omitted.
            start_engine_loop: Forwarded to the constructor.
            usage_context: Forwarded to the constructor.
            stat_loggers: Forwarded to the constructor.
        """
        config = (engine_config if engine_config is not None else
                  engine_args.create_engine_config())

        executor_class = cls._get_executor_cls(config)
        # Ray-based executors need the cluster initialised before the
        # engine is constructed.
        if executor_class.uses_ray:
            initialize_ray_cluster(config.parallel_config)

        return cls(
            **config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

683
684
    @property
    def is_running(self) -> bool:
        """True while a background loop task exists and has not finished."""
        if self.background_loop is None:
            return False
        task = self._background_loop_unshielded
        return task is not None and not task.done()

    @property
    def is_stopped(self) -> bool:
        """True if the engine errored or its background loop task is done."""
        if self.errored:
            return True
        task = self._background_loop_unshielded
        return (self.background_loop is not None and task is not None
                and task.done())

    @property
    def errored(self) -> bool:
        """Whether an error has been recorded via :meth:`set_errored`."""
        recorded = self._errored_with
        return recorded is not None

699
    @property
    def dead_error(self) -> BaseException:
        """A new :class:`AsyncEngineDeadError` for a loop that is not running."""
        message = ("Background loop is not running. If it was running, "
                   "inspect the output to find the stacktrace of the "
                   "error that caused the background loop to stop "
                   "(AsyncEngineDeadError).")
        return AsyncEngineDeadError(message)
706

707
708
709
710
711
712
    def set_errored(self, exc: Exception) -> None:
        """Remember ``exc``; :attr:`errored` reports True from now on."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Error hook for the background loop task.

        Marks the engine as errored, then hands ``exc`` to the request
        tracker's ``propagate_exception``.
        """
        self.set_errored(exc)
        tracker = self._request_tracker
        tracker.propagate_exception(exc)
713

714
715
716
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer for ``lora_request`` from the engine's
        tokenizer group."""
        tokenizer_group = self.engine.get_tokenizer_group()
        return await tokenizer_group.get_lora_tokenizer_async(lora_request)
720

721
    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: if the engine has already errored.
            RuntimeError: if the loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # Created here rather than in __init__ so that it uses the right
        # event loop.
        self._request_tracker = RequestTracker()

        # Only a weak reference goes to the loop coroutine, so the running
        # task does not keep this engine alive.
        loop = asyncio.get_event_loop()
        task = loop.create_task(self.run_engine_loop(weakref.ref(self)))
        self._background_loop_unshielded = task
        task.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        self.background_loop = asyncio.shield(task)
Antoni Baum's avatar
Antoni Baum committed
736

737
738
739
740
741
742
743
744
745
746
747
748
749
750
    def shutdown_background_loop(self) -> None:
        """Cancel the background loop task and drop references to it.

        Call this during cleanup so that nothing keeps ``self`` (and the
        executors and other resources the engine holds) alive, allowing them
        to be garbage collected.
        """
        task, self._background_loop_unshielded = (
            self._background_loop_unshielded, None)
        if task is not None:
            task.cancel()
        self.background_loop = None

751
    async def engine_step(self, virtual_engine: int) -> bool:
        """Run one engine iteration for ``virtual_engine``.

        Pending new requests are submitted to the engine. A ``ValueError``
        raised while adding one is reported through the tracker's
        ``process_exception``. Pending aborts are then applied and a single
        ``step_async`` is run.

        Returns:
            True if at least one request output from this step is unfinished.
        """
        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())

        for request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            try:
                await self.engine.add_request_async(**request)
            except ValueError as err:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    request["request_id"],
                    err,
                    verbose=self.log_requests,
                )

        if aborted_requests:
            await self._engine_abort(aborted_requests)

        request_outputs = await self.engine.step_async(virtual_engine)

        if self.use_process_request_outputs_callback:
            # The callback already pushed the outputs into their streams from
            # inside LLMEngine's _process_model_outputs; only detect whether
            # everything has finished.
            all_finished = all(output.finished for output in request_outputs)
        else:
            all_finished = self.process_request_outputs(request_outputs)

        return not all_finished

    def process_request_outputs(self, request_outputs) -> bool:
        """Pass every output to the request tracker.

        Returns:
            True if every output is marked finished (also True for an empty
            batch).
        """
        finished_flags = []
        for output in request_outputs:
            self._request_tracker.process_request_output(
                output, verbose=self.log_requests)
            finished_flags.append(output.finished)
        return all(finished_flags)
798

Antoni Baum's avatar
Antoni Baum committed
799
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Forward ``request_ids`` to the wrapped engine's ``abort_request``."""
        engine = self.engine
        engine.abort_request(request_ids)
Antoni Baum's avatar
Antoni Baum committed
801

802
803
804
805
806
807
808
809
    @staticmethod
    async def run_engine_loop(engine_ref: ReferenceType) -> None:
        """Drive :meth:`engine_step` for every virtual engine until the
        engine goes away.

        We use a weakref to the engine so that the running loop
        doesn't prevent the engine being garbage collected. The loop returns
        when the weakref is found dead. Exceptions from an ``engine_step``
        task propagate via ``task.result()``. If one wait for completed
        steps exceeds ``ENGINE_ITERATION_TIMEOUT_S``, the engine is marked
        errored and ``asyncio.TimeoutError`` is re-raised.
        """
        engine: Optional["AsyncLLMEngine"] = engine_ref()
        if not engine:
            return

        pipeline_parallel_size = \
                engine.engine.parallel_config.pipeline_parallel_size
        # One flag per virtual engine: is an engine_step task running for it?
        has_requests_in_progress = [False] * pipeline_parallel_size
        while True:
            if not any(has_requests_in_progress):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                await engine.engine.stop_remote_worker_execution_loop_async()
                request_tracker = engine._request_tracker
                # Allow engine to be garbage collected while
                # waiting for new requests
                del engine
                await asyncio.sleep(0)
                if engine_ref() is None:
                    return
                await request_tracker.wait_for_new_requests()
                # Re-acquire a strong reference; the engine may have been
                # collected while we were waiting.
                engine = engine_ref()
                if not engine:
                    return
                logger.debug("Got new requests!")
                # Index i of requests_in_progress is the task for virtual
                # engine i.
                requests_in_progress = [
                    asyncio.create_task(engine.engine_step(ve))
                    for ve in range(pipeline_parallel_size)
                ]
                has_requests_in_progress = [True] * pipeline_parallel_size

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    # NOTE(review): tasks for idle virtual engines stay in
                    # requests_in_progress after finishing, so this wait
                    # returns immediately while other virtual engines are
                    # still busy. The loop then re-polls
                    # has_unfinished_requests_for_virtual_engine for them;
                    # confirm this busy-polling is intended.
                    done, _ = await asyncio.wait(
                        requests_in_progress,
                        return_when=asyncio.FIRST_COMPLETED)
                    for _ in range(pipeline_parallel_size):
                        await asyncio.sleep(0)
                for task in done:
                    result = task.result()
                    virtual_engine = requests_in_progress.index(task)
                    has_unfinished_requests = (
                        engine.engine.
                        has_unfinished_requests_for_virtual_engine(
                            virtual_engine))
                    # Schedule another step if the finished step reported
                    # in-progress work or the scheduler still holds requests
                    # for this virtual engine.
                    if result or has_unfinished_requests:
                        requests_in_progress[virtual_engine] = (
                            asyncio.create_task(
                                engine.engine_step(virtual_engine)))
                        has_requests_in_progress[virtual_engine] = True
                    else:
                        has_requests_in_progress[virtual_engine] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                engine.set_errored(exc)
                raise
            await asyncio.sleep(0)

871
872
    # This method does not need to be async, but kept that way
    # for backwards compatibility.
    @overload  # DEPRECATED
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, EmbeddingRequestOutput], None]]:
        ...

    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, EmbeddingRequestOutput], None]]:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request(
        self,
        request_id: str,
        prompt: Optional[PromptType] = None,
        params: Optional[Union[SamplingParams, PoolingParams]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Register a request with the tracker and return its output stream.

        The background loop is started on demand when ``start_engine_loop``
        is enabled.

        Args:
            request_id: The unique id of the request.
            prompt: The prompt to the LLM.
            params: Sampling (generation) or pooling (embedding) parameters.
            arrival_time: Arrival timestamp; defaults to ``time.time()`` when
                omitted.
            lora_request: LoRA request to use, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use, if any.
            priority: The priority of the request. Only allowed to be
                non-zero with the "priority" scheduling policy.
            inputs: Deprecated alias for ``prompt``.

        Returns:
            An async generator over the request's outputs.

        Raises:
            AsyncEngineDeadError: if the loop is not running and automatic
                start is disabled, or if the engine has already errored.
            ValueError: if ``priority`` is non-zero but priority scheduling
                is not enabled.
        """
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if not self.is_running:
            if not self.start_engine_loop:
                raise self.dead_error
            self.start_background_loop()

        if priority != 0 and self.engine.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        # An explicit arrival_time of 0.0 is a valid timestamp; fall back to
        # "now" only when the caller did not provide one.
        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            prompt=prompt,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

        return stream.generator()
953

954
    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is an async generator. It
        adds the request into the waiting queue of the LLMEngine and streams
        the outputs from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use
                for generation, if any.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request, each passed through
            :meth:`LLMEngine.validate_output`.

        Raises:
            AsyncEngineDeadError: if the background loop is not running and
                cannot be started (see :meth:`add_request`).
            ValueError: if ``priority`` is non-zero without priority
                scheduling.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming HTTP request (see api_server.py)
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
        ):
            yield LLMEngine.validate_output(output, RequestOutput)
1039
1040
1041

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model.

        Generate outputs for a request. This method is an async generator. It
        adds the request into the waiting queue of the LLMEngine and streams
        the outputs from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Yields:
            The output `EmbeddingRequestOutput` objects from the LLMEngine
            for the request, each passed through
            :meth:`LLMEngine.validate_output`.

        Raises:
            AsyncEngineDeadError: if the background loop is not running and
                cannot be started (see :meth:`add_request`).
            ValueError: if ``priority`` is non-zero without priority
                scheduling.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming HTTP request (see api_server.py)
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
        ):
            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
1119

Antoni Baum's avatar
Antoni Baum committed
1120
1121
    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: if the background loop is not running.
        """
        if not self.is_running:
            # Same error (and message) as every other "loop not running"
            # path; built by the dead_error property.
            raise self.dead_error

        return self._abort(request_id)
1137

Antoni Baum's avatar
Antoni Baum committed
1138
    def _abort(self, request_id: str) -> None:
        """Abort ``request_id`` via the request tracker, with no liveness check.

        If the request is finished or not found, this is a no-op. The tracker
        is called with ``asyncio.CancelledError`` as its ``exception``
        argument. Unlike :meth:`abort`, this does not verify that the
        background loop is running.
        """
        tracker = self._request_tracker
        tracker.abort_request(request_id,
                              exception=asyncio.CancelledError,
                              verbose=self.log_requests)
1150

1151
1152
    async def get_model_config(self) -> ModelConfig:
        """Return the wrapped engine's :class:`ModelConfig`."""
        engine = self.engine
        return engine.get_model_config()
1154

1155
1156
    async def get_parallel_config(self) -> ParallelConfig:
        """Return the wrapped engine's :class:`ParallelConfig`."""
        engine = self.engine
        return engine.get_parallel_config()
1158

1159
1160
    async def get_decoding_config(self) -> DecodingConfig:
        """Return the wrapped engine's :class:`DecodingConfig`."""
        engine = self.engine
        return engine.get_decoding_config()
1162

1163
1164
    async def get_scheduler_config(self) -> SchedulerConfig:
        """Return the wrapped engine's :class:`SchedulerConfig`."""
        engine = self.engine
        return engine.get_scheduler_config()
1166
1167
1168

    async def get_lora_config(self) -> LoRAConfig:
        """Return the wrapped engine's :class:`LoRAConfig`."""
        engine = self.engine
        return engine.get_lora_config()
1170

1171
1172
1173
1174
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the wrapped engine to log its stats.

        ``scheduler_outputs`` and ``model_output`` are accepted but not
        forwarded; the engine's ``do_log_stats`` is called without arguments.
        """
        engine = self.engine
        engine.do_log_stats()
1176

1177
    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy.

        Raises:
            AsyncEngineDeadError: if the background loop is stopped.
            Any exception raised by the engine's ``check_health_async``.
        """
        started = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        await self.engine.check_health_async()
        elapsed = time.perf_counter() - started
        logger.debug("Health check took %fs", elapsed)
1186
1187

    async def is_tracing_enabled(self) -> bool:
        """Report whether the wrapped engine has tracing enabled."""
        engine = self.engine
        return engine.is_tracing_enabled()
1189
1190

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register stat ``logger`` on the wrapped engine as ``logger_name``."""
        register = self.engine.add_logger
        register(logger_name=logger_name, logger=logger)
1192
1193

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger named ``logger_name`` from the engine."""
        unregister = self.engine.remove_logger
        unregister(logger_name=logger_name)
1195
1196

    async def start_profile(self) -> None:
        """Start profiling on the model executor (or its workers)."""
        self._invoke_profiler("start_profile")

    async def stop_profile(self) -> None:
        """Stop profiling on the model executor (or its workers)."""
        self._invoke_profiler("stop_profile")

    def _invoke_profiler(self, action: str) -> None:
        """Run the profiler method named ``action`` on the model executor.

        An executor whose exact type is GPUExecutorAsync has the method
        called directly. Any other executor (subclasses included) is asked to
        run it on its workers via ``_run_workers``.
        """
        executor = self.engine.model_executor
        # using type instead of isinstance to check to avoid capturing
        # inherited classes
        if type(executor) == GPUExecutorAsync:  # noqa: E721
            getattr(executor, action)()
        else:
            executor._run_workers(action)