async_llm_engine.py 49.1 KB
Newer Older
1
2
import asyncio
import time
3
import weakref
Antoni Baum's avatar
Antoni Baum committed
4
from functools import partial
5
6
from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
                    List, Mapping, Optional, Set, Tuple, Type, Union, overload)
7
from weakref import ReferenceType
8

9
import vllm.envs as envs
10
11
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
12
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
13
from vllm.engine.arg_utils import AsyncEngineArgs
14
from vllm.engine.async_timeout import asyncio_timeout
15
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
16
from vllm.engine.metrics_types import StatLoggerBase
17
from vllm.executor.executor_base import ExecutorAsyncBase
18
from vllm.executor.gpu_executor import GPUExecutorAsync
19
from vllm.executor.ray_utils import initialize_ray_cluster
20
from vllm.inputs import PromptType
Woosuk Kwon's avatar
Woosuk Kwon committed
21
from vllm.logger import init_logger
22
from vllm.lora.request import LoRARequest
23
24
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
25
from vllm.model_executor.layers.sampler import SamplerOutput
26
27
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
28
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
29
from vllm.sampling_params import SamplingParams
30
from vllm.sequence import ExecuteModelRequest
31
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
32
from vllm.usage.usage_lib import UsageContext
33
from vllm.utils import deprecate_kwargs, weak_bind
34
35

# Module-level logger for this file, created via vLLM's ``init_logger`` helper.
logger = init_logger(__name__)
36
# Engine iteration timeout read from ``vllm.envs``. NOTE(review): the ``_S``
# suffix suggests seconds, and it is presumably consumed by the engine loop
# further down the file (outside this view) -- confirm units and usage.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
37

Antoni Baum's avatar
Antoni Baum committed
38

39
40
41
42
class AsyncEngineDeadError(RuntimeError):
    """Signals that the engine's background task has died.

    Raised by ``_log_task_completion`` (chained from the original error)
    when the engine background task finishes with an exception.
    """


43
44
45
46
47
48
49
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback meant only for the ``engine.run_engine_loop()`` task.

    That task runs a ``while True`` loop which can only exit by raising, so
    every way the task can end is handled here:

    * cancellation is logged as a graceful shutdown and swallowed;
    * a normal return is itself an error: it is turned into an
      ``AssertionError`` and handled like any other failure;
    * a failure is logged, handed to ``error_callback``, and re-raised
      wrapped in :class:`AsyncEngineDeadError`.

    Args:
        task: The finished engine background task.
        error_callback: Receives the exception that ended the task.

    Raises:
        AsyncEngineDeadError: Whenever the task did not end by cancellation.
    """
    try:
        return_value = task.result()
        # Raised inside the ``try`` on purpose so the generic handler below
        # processes it exactly like a real crash of the loop.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {return_value}")
    except asyncio.exceptions.CancelledError:
        # Cancellation is assumed to be a graceful shutdown, which should
        # only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as failure:
        logger.error("Engine background task failed", exc_info=failure)
        error_callback(failure)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from failure
69
70


71
72
73
# Sentinel object placed on an AsyncStream's queue by AsyncStream.finish()
# when the stream ends without an error. AsyncStream.generator() returns
# (ending iteration normally) when it dequeues this exact object.
STOP_ITERATION = Exception()  # Sentinel


Antoni Baum's avatar
Antoni Baum committed
74
class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator.

    Producers call :meth:`put` for each output and :meth:`finish` once at
    the end. The consumer iterates :meth:`generator`. Items travel through
    an internal ``asyncio.Queue``. After :meth:`finish` has been called,
    further :meth:`put` and :meth:`finish` calls are silently ignored.
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        """
        Args:
            request_id: Id of the request this stream belongs to.
            cancel: Called with ``request_id`` if the consumer closes the
                generator before the stream has ended.
        """
        self.request_id = request_id
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Enqueue ``item`` for the consumer; dropped once finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Mark the stream finished and enqueue its terminal item.

        If ``exception`` is an exception instance or class, it is enqueued
        and later raised by :meth:`generator`. Otherwise the module-level
        ``STOP_ITERATION`` sentinel is enqueued so the generator ends
        normally. Only the first call has any effect.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has been called."""
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Yield queued outputs until the stream ends.

        Raises the enqueued exception if the stream was finished with one.
        If the consumer closes the generator early (``GeneratorExit``), the
        ``cancel`` callback is invoked with the request id and
        ``asyncio.CancelledError`` is raised in place of ``GeneratorExit``.
        """
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    # Identity, not equality: STOP_ITERATION is a sentinel
                    # object, and an exception type with a custom __eq__
                    # must never be mistaken for end-of-stream.
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """Return True if ``value`` is an exception instance or class."""
        return isinstance(value, BaseException) or \
                (isinstance(value, type) and \
                 issubclass(value, BaseException))

Antoni Baum's avatar
Antoni Baum committed
123

124
125
126
127
128
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    New requests and abort notices are buffered in queues and handed to the
    engine in batches by :meth:`get_new_and_aborted_requests`. That call
    also moves accepted requests into the tracked ``request_id -> stream``
    map.
    """

    def __init__(self) -> None:
        # request_id -> stream for requests already handed to the engine.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Ids whose abort has not yet been reported to the engine.
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, engine add_request kwargs) pairs not yet handed off.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set by add_request(); awaited/cleared by wait_for_new_requests().
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Abort ``request_id`` with ``exc``, or abort every tracked request
        with ``exc`` when ``request_id`` is None."""
        # Snapshot the ids up front: abort_request() pops entries out of
        # self._request_streams, so the dict cannot be iterated directly.
        if request_id is None:
            targets = list(self._request_streams)
        else:
            targets = [request_id]
        for rid in targets:
            self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Forward an engine output to its stream, closing the stream (and
        forgetting it) when the output is the final one."""
        request_id = request_output.request_id
        finished = request_output.finished

        streams = self._request_streams
        stream = (streams.pop(request_id, None)
                  if finished else streams.get(request_id))

        # No stream remains if the request was aborted while this output
        # was being generated; in that case the output is simply dropped.
        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if finished and verbose:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Terminate ``request_id``'s stream with an engine-side exception."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Queue a request for the engine's next background-loop iteration.

        Returns:
            The stream the caller should consume outputs from.

        Raises:
            KeyError: If ``request_id`` is already being tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id,
                             partial(self.abort_request, verbose=verbose))
        payload = {"request_id": request_id, **engine_add_request_kwargs}
        self._new_requests.put_nowait((stream, payload))

        # Wake up a loop blocked in wait_for_new_requests().
        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Abort a request during next background loop iteration.

        The id is queued for the engine, and the request's stream (if one
        is tracked) is removed and finished with ``exception``.
        """
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)

        if (stream := self._request_streams.pop(request_id, None)) is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Drain both queues.

        Returns:
            ``(new_requests, finished_requests)``: the engine kwargs of newly
            accepted requests, and the ids the engine should abort.
        """
        finished_requests: Set[str] = set()
        while not self._aborted_requests.empty():
            finished_requests.add(self._aborted_requests.get_nowait())

        new_requests: List[Dict] = []
        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            rid = stream.request_id
            if rid not in finished_requests:
                self._request_streams[rid] = stream
                new_requests.append(new_request)
                continue
            # Aborted before the engine ever saw it: end the stream here and
            # drop the id so the engine is not asked to abort it.
            stream.finish(asyncio.CancelledError)
            finished_requests.discard(rid)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Block until at least one new request is queued, then reset the
        new-request event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Whether any request is waiting to be handed to the engine."""
        return self._new_requests.qsize() > 0
256

Antoni Baum's avatar
Antoni Baum committed
257
258
259
260

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods.

    Adds ``step_async`` (one scheduling + model-execution iteration that
    awaits the executor), ``add_request_async`` (awaits input preprocessing
    and guided-decoding setup), plus async wrappers for stopping the remote
    worker loop and for health checks.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.

        Args:
            virtual_engine: Index selecting the scheduler, scheduler context
                and cached scheduler outputs used for this iteration.

        Returns:
            ``request_outputs`` of that virtual engine's scheduler context.
            It is cleared at the start of every call. In the multi-step case,
            while steps remain, it is returned early, before this step's
            outputs are appended to the context.
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):

            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            # Nothing scheduled: flush any pending async outputs instead.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # Clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            if outputs and allow_async_output_proc:
                assert len(
                    outputs
                ) == 1, "Async postprocessor expects only a single output set"
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)

        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    @overload  # DEPRECATED
    async def add_request_async(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @overload
    async def add_request_async(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request_async(
            self,
            request_id: str,
            prompt: Optional[PromptType] = None,
            params: Optional[Union[SamplingParams, PoolingParams]] = None,
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
            *,
            inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        """Async version of :meth:`add_request`.

        Preprocesses the prompt asynchronously and, for sampling requests
        with guided decoding, builds the guided-decoding logits processor
        asynchronously before registering the request.

        Raises:
            ValueError: If no prompt (``prompt`` or the deprecated
                ``inputs``) or no ``params`` is given; if ``lora_request`` is
                given while LoRA is not enabled; or if a non-zero
                ``priority`` is given while the scheduler policy is not
                ``"priority"``.
        """
        if inputs is not None:
            prompt = inputs
        # Explicit check instead of `assert`: asserts are stripped under -O,
        # which would let a missing prompt/params slip through silently.
        if prompt is None or params is None:
            raise ValueError("Both 'prompt' (or the deprecated 'inputs') and "
                             "'params' must be provided.")

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")
        if arrival_time is None:
            arrival_time = time.time()

        preprocessed_inputs = await self.input_preprocessor.preprocess_async(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        if isinstance(params, SamplingParams) and \
            params.guided_decoding is not None:
            # Guided decoding has an async implementation for building logits
            # processors in a separate threadpool.
            # We want to invoke that here instead of using the blocking
            # implementation in the LLMEngine
            params = await build_guided_decoding_logits_processor_async(
                sampling_params=params,
                tokenizer=self.get_tokenizer(lora_request),
                default_guided_backend=self.decoding_config.
                guided_decoding_backend)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    async def check_health_async(self) -> None:
        """Run the tokenizer's health check (if a tokenizer is set), then
        the model executor's health check."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
509

510

511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
async def build_guided_decoding_logits_processor_async(
        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
        default_guided_backend: str) -> SamplingParams:
    """Attach a guided-decoding logits processor to ``sampling_params``.

    Only the ``guided_decoding`` field is consulted. If it is ``None``, the
    params are returned unchanged. Otherwise:

    * its ``backend`` is set to ``default_guided_backend`` when unset
      (this mutates the ``guided_decoding`` object in place);
    * a processor is awaited from ``get_guided_decoding_logits_processor``;
    * if a processor is returned, it is appended to
      ``sampling_params.logits_processors``, which is created as an empty
      list when it was ``None``;
    * ``sampling_params.guided_decoding`` is reset to ``None``.

    Args:
        sampling_params: Params to update; modified in place.
        tokenizer: Tokenizer passed through to the processor builder.
        default_guided_backend: Backend used when the request names none.

    Returns:
        The same ``sampling_params`` object.
    """
    if (guided_decoding := sampling_params.guided_decoding) is None:
        return sampling_params

    logger.debug("Building guided decoding logits processor. "
                 "Params: %s", guided_decoding)

    guided_decoding.backend = guided_decoding.backend or default_guided_backend

    processor = await get_guided_decoding_logits_processor(
        guided_params=guided_decoding, tokenizer=tokenizer)

    if processor:
        if sampling_params.logits_processors is None:
            sampling_params.logits_processors = []
        sampling_params.logits_processors.append(processor)

    # Unset guided decoding params after constructing the lp from them
    sampling_params.guided_decoding = None

    return sampling_params


541
class AsyncLLMEngine:
542
    """An asynchronous wrapper for :class:`LLMEngine`.

    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.

    Args:
        log_requests: Whether to log the requests.
        start_engine_loop: If True, the background task to run the engine
            will be automatically started by the first ``generate`` /
            ``add_request`` call.
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
    """
557

Antoni Baum's avatar
Antoni Baum committed
558
559
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

560
561
562
    def __init__(self,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Construct the wrapped engine and reset background-loop state.

        Args:
            *args: Positional arguments forwarded to ``self._engine_class``.
            log_requests: Whether request events are logged.
            start_engine_loop: Whether ``add_request`` may start the
                background loop on demand when it is not running.
            **kwargs: Keyword arguments forwarded to ``self._engine_class``.
        """
        self.log_requests = log_requests
        self.engine = self._engine_class(*args, **kwargs)

        # With async output processing, the engine hands request outputs to
        # us through a callback, so the asyncio queues are fed promptly
        # (important for multi-step). weak_bind is used so the engine does
        # not keep `self` alive (presumably via a weak reference -- see
        # vllm.utils.weak_bind).
        use_callback = self.engine.model_config.use_async_output_proc
        self.use_process_request_outputs_callback = use_callback
        if use_callback:
            self.engine.process_request_outputs_callback = weak_bind(
                self.process_request_outputs)

        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Shielded view of the engine-loop task (set by
        # start_background_loop).
        self.background_loop: Optional[asyncio.Future] = None
        # The raw task is kept as well so it cannot be garbage collected
        # while it is running.
        self._background_loop_unshielded: Optional[asyncio.Task] = None

        # Created lazily in start_background_loop() so that it is bound to
        # the right event loop.
        self._request_tracker: RequestTracker

589
590
591
592
593
    def __del__(self):
        """Best-effort wake-up of the engine loop when this object is freed.

        :meth:`run_engine_loop` holds only a weak reference to the engine.
        While idle it waits in ``request_tracker.wait_for_new_requests()``.
        Setting the tracker's ``new_requests_event`` is meant to release that
        wait, so the loop can observe that the engine is gone and return.
        """
        # The tracker is stored as ``_request_tracker`` by
        # start_background_loop(). It does not exist if the loop was never
        # started. (Looking up ``request_tracker`` always missed, so the
        # wake-up never happened.)
        rt = getattr(self, "_request_tracker", None)
        if rt is not None:
            # Wake up engine loop so that it will exit cleanly
            rt.new_requests_event.set()

594
    @classmethod
    def _get_executor_cls(
            cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
        """Pick the async executor class for ``engine_config``.

        Precedence:

        1. An executor class passed directly as
           ``distributed_executor_backend`` is used as-is (after a subclass
           check).
        2. Otherwise the device type (neuron, tpu, cpu, openvino, xpu)
           decides, consulting the backend name where relevant.
        3. Otherwise the backend name selects a GPU executor ("ray" or
           "mp"), falling back to the single-process GPU executor.

        Raises:
            TypeError: if a custom backend class does not subclass
                ``ExecutorAsyncBase``.
            RuntimeError: for an unsupported backend on XPU.
            AssertionError: for a distributed backend on openvino, or a
                non-"ray" backend on TPU.
        """
        backend = engine_config.parallel_config.distributed_executor_backend

        if isinstance(backend, type):
            if issubclass(backend, ExecutorAsyncBase):
                return backend
            raise TypeError(
                "distributed_executor_backend must be a subclass of "
                f"ExecutorAsyncBase. Got {backend}.")

        device_type = engine_config.device_config.device_type

        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            return NeuronExecutorAsync

        if device_type == "tpu":
            if backend == "ray":
                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                return RayTPUExecutorAsync
            assert backend is None
            from vllm.executor.tpu_executor import TPUExecutorAsync
            return TPUExecutorAsync

        if device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutorAsync
            return CPUExecutorAsync

        if device_type == "openvino":
            assert backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            return OpenVINOExecutorAsync

        if device_type == "xpu":
            if backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                return XPUExecutorAsync
            if backend == "ray":
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                return RayXPUExecutorAsync
            if backend == "mp":
                from vllm.executor.multiproc_xpu_executor import (
                    MultiprocessingXPUExecutorAsync)
                return MultiprocessingXPUExecutorAsync
            raise RuntimeError(
                "Not supported distributed execution model on XPU device.")

        # Remaining cases run on GPU.
        if backend == "ray":
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            return RayGPUExecutorAsync
        if backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            return MultiprocessingGPUExecutorAsync
        from vllm.executor.gpu_executor import GPUExecutorAsync
        return GPUExecutorAsync

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        engine_config: Optional[EngineConfig] = None,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Build an :class:`AsyncLLMEngine` from :class:`AsyncEngineArgs`.

        Args:
            engine_args: User-facing engine arguments.
            engine_config: Pre-built config. When omitted it is created with
                ``engine_args.create_engine_config()``.
            start_engine_loop: Forwarded to the constructor.
            usage_context: Forwarded to the wrapped engine.
            stat_loggers: Optional custom stat loggers keyed by name.

        Returns:
            The constructed engine.
        """
        config = (engine_config if engine_config is not None else
                  engine_args.create_engine_config())
        executor_cls = cls._get_executor_cls(config)

        # Ray-based executors get the Ray cluster initialised before the
        # engine (and its executor) is constructed.
        if executor_cls.uses_ray:
            initialize_ray_cluster(config.parallel_config)

        return cls(
            **config.to_dict(),
            executor_class=executor_cls,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

682
683
    @property
    def is_running(self) -> bool:
        """Whether the background loop task exists and has not finished."""
        task = self._background_loop_unshielded
        if self.background_loop is None or task is None:
            return False
        return not task.done()

    @property
    def is_stopped(self) -> bool:
        """Whether the engine errored or its background loop task is done."""
        if self.errored:
            return True
        task = self._background_loop_unshielded
        return (self.background_loop is not None and task is not None
                and task.done())

    @property
    def errored(self) -> bool:
        """Whether an exception was recorded through :meth:`set_errored`."""
        recorded = self._errored_with
        return recorded is not None

698
    @property
    def dead_error(self) -> BaseException:
        """Build (without raising) the error reported when the loop is dead."""
        message = ("Background loop is not running. If it was running, "
                   "inspect the output to find the stacktrace of the "
                   "error that caused the background loop to stop "
                   "(AsyncEngineDeadError).")
        return AsyncEngineDeadError(message)
705

706
707
708
709
710
711
    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` as the reason the engine can no longer be used.

        Afterwards :attr:`errored` and :attr:`is_stopped` report True, and
        :meth:`start_background_loop` refuses to start.
        """
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Hook invoked when the background loop task fails.

        Marks the engine as errored, then hands ``exc`` to the request
        tracker's ``propagate_exception``.
        """
        self.set_errored(exc)
        tracker = self._request_tracker
        tracker.propagate_exception(exc)
712

713
714
715
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Fetch the tokenizer for ``lora_request`` from the tokenizer group.

        Args:
            lora_request: Optional LoRA request used to select the tokenizer.
        """
        tokenizer_group = self.engine.get_tokenizer_group()
        return await tokenizer_group.get_lora_tokenizer_async(lora_request)
719

720
    def start_background_loop(self) -> None:
        """Start the background loop.

        Creates a fresh request tracker, schedules :meth:`run_engine_loop`
        on the current event loop, and stores both the raw task and a
        shielded view of it.

        Raises:
            AsyncEngineDeadError: if the engine already errored (chained
                from the recorded error).
            RuntimeError: if the loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # Created here rather than in __init__ so it uses the right loop.
        self._request_tracker = RequestTracker()

        # The loop coroutine only receives a weak reference to `self`.
        event_loop = asyncio.get_event_loop()
        loop_task = event_loop.create_task(
            self.run_engine_loop(weakref.ref(self)))
        self._background_loop_unshielded = loop_task
        on_done = partial(_log_task_completion,
                          error_callback=self._error_callback)
        loop_task.add_done_callback(on_done)
        self.background_loop = asyncio.shield(loop_task)
Antoni Baum's avatar
Antoni Baum committed
735

736
737
738
739
740
741
742
743
744
745
746
747
748
749
    def shutdown_background_loop(self) -> None:
        """
        Shut down the background loop.

        Cancels the engine-loop task, if there is one, and clears both the
        task and its shielded wrapper. Call this during cleanup so that
        references to `self` are dropped and the resources held by the async
        LLM engine (e.g., the executors and their resources) can be garbage
        collected.
        """
        loop_task = self._background_loop_unshielded
        if loop_task is not None:
            loop_task.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

750
    async def engine_step(self, virtual_engine: int) -> bool:
        """Run one iteration of the engine for ``virtual_engine``.

        Steps:

        1. Forwards newly submitted requests to the engine.
        2. Applies pending aborts.
        3. Calls ``step_async``.
        4. Routes the outputs to their streams, unless the engine already
           did that through the output callback.

        Returns:
            True unless every output produced by this step is finished.
            An empty step therefore returns False.
        """
        pending, aborted = (
            self._request_tracker.get_new_and_aborted_requests())

        # Queue each new request in the engine. A ValueError (e.g. failed
        # validation) is passed to the tracker for that request only instead
        # of failing the whole step.
        for req in pending:
            try:
                await self.engine.add_request_async(**req)
            except ValueError as err:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    req["request_id"], err, verbose=self.log_requests)

        if aborted:
            await self._engine_abort(aborted)

        outputs = await self.engine.step_async(virtual_engine)

        if self.use_process_request_outputs_callback:
            # Outputs were already delivered by the callback inside
            # LLMEngine's _process_model_outputs. Only completion matters.
            everything_done = all(out.finished for out in outputs)
        else:
            everything_done = self.process_request_outputs(outputs)

        return not everything_done

    def process_request_outputs(self, request_outputs) -> bool:
        """Put each output into its request's stream via the tracker.

        Also installed as the engine's output callback when async output
        processing is enabled.

        Returns:
            True if every output is finished. This is vacuously True for an
            empty ``request_outputs``.
        """
        finished_all = True
        for out in request_outputs:
            self._request_tracker.process_request_output(
                out, verbose=self.log_requests)
            if not out.finished:
                finished_all = False
        return finished_all
797

Antoni Baum's avatar
Antoni Baum committed
798
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Forward an abort for ``request_ids`` to the wrapped engine."""
        engine = self.engine
        engine.abort_request(request_ids)
Antoni Baum's avatar
Antoni Baum committed
800

801
802
803
804
805
806
807
808
    @staticmethod
    async def run_engine_loop(engine_ref: ReferenceType):
        """We use a weakref to the engine so that the running loop
        doesn't prevent the engine being garbage collected.

        Drives one ``engine_step`` task per pipeline-parallel virtual engine.

        When no virtual engine has work, the loop:

        * stops the remote worker execution loop;
        * drops its strong engine reference;
        * waits on the request tracker for new requests.

        It returns as soon as the weakref is found dead. Each iteration is
        bounded by ``ENGINE_ITERATION_TIMEOUT_S``. On timeout the engine is
        marked errored and ``asyncio.TimeoutError`` is re-raised. An
        exception raised by an ``engine_step`` task is re-raised by
        ``task.result()`` and also ends the loop.

        Args:
            engine_ref: Weak reference to the owning ``AsyncLLMEngine``.
        """
        engine: Optional["AsyncLLMEngine"] = engine_ref()
        if not engine:
            return

        pipeline_parallel_size = \
                engine.engine.parallel_config.pipeline_parallel_size
        # One flag per virtual engine: True while a step task is expected to
        # be scheduled for it.
        has_requests_in_progress = [False] * pipeline_parallel_size
        while True:
            if not any(has_requests_in_progress):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                await engine.engine.stop_remote_worker_execution_loop_async()
                request_tracker = engine._request_tracker
                # Allow engine to be garbage collected while
                # waiting for new requests
                del engine
                # Yield to the event loop once, then exit if the engine has
                # already been collected.
                await asyncio.sleep(0)
                if engine_ref() is None:
                    return
                await request_tracker.wait_for_new_requests()
                # Re-acquire a strong reference. The engine may have been
                # collected while we were waiting.
                engine = engine_ref()
                if not engine:
                    return
                logger.debug("Got new requests!")
                # (Re)start one step task per virtual engine.
                requests_in_progress = [
                    asyncio.create_task(engine.engine_step(ve))
                    for ve in range(pipeline_parallel_size)
                ]
                has_requests_in_progress = [True] * pipeline_parallel_size

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        requests_in_progress,
                        return_when=asyncio.FIRST_COMPLETED)
                    # Yield once per virtual engine. Presumably this lets
                    # other step tasks make progress before results are
                    # handled -- TODO confirm intent.
                    for _ in range(pipeline_parallel_size):
                        await asyncio.sleep(0)
                # NOTE: an idle virtual engine's completed task stays in
                # requests_in_progress, so asyncio.wait returns it again
                # immediately. That re-polls
                # has_unfinished_requests_for_virtual_engine for idle
                # virtual engines while others are busy.
                for task in done:
                    result = task.result()
                    virtual_engine = requests_in_progress.index(task)
                    has_unfinished_requests = (
                        engine.engine.
                        has_unfinished_requests_for_virtual_engine(
                            virtual_engine))
                    # Reschedule this virtual engine if its last step
                    # reported in-progress requests, or if the engine still
                    # has unfinished work for it. Otherwise mark it idle.
                    if result or has_unfinished_requests:
                        requests_in_progress[virtual_engine] = (
                            asyncio.create_task(
                                engine.engine_step(virtual_engine)))
                        has_requests_in_progress[virtual_engine] = True
                    else:
                        has_requests_in_progress[virtual_engine] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                engine.set_errored(exc)
                raise
            await asyncio.sleep(0)

870
871
    # This method does not need to be async, but kept that way
    # for backwards compatibility.
    @overload  # DEPRECATED
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, EmbeddingRequestOutput], None]]:
        # Typing-only stub for the deprecated keyword form
        # ``add_request(request_id, inputs=..., params=...)``. The runtime
        # implementation is the ``async def add_request`` further down. It
        # returns a coroutine that resolves to the output stream, hence the
        # ``Coroutine[...]`` return annotation.
        ...

    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, EmbeddingRequestOutput], None]]:
        # Typing-only stub for the preferred form
        # ``add_request(request_id, prompt, params, ...)``. Awaiting the call
        # yields the async generator of request outputs.
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request(
        self,
        request_id: str,
        prompt: Optional[PromptType] = None,
        params: Optional[Union[SamplingParams, PoolingParams]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Register a request and return the stream of its outputs.

        The request is recorded in the request tracker. The background loop
        hands it to the wrapped engine on its next iteration. If the loop is
        not running and ``start_engine_loop`` is set, it is started here.

        Args:
            request_id: Unique id of the request.
            prompt: The prompt (see :class:`~vllm.inputs.PromptType`).
            params: Sampling or pooling parameters.
            arrival_time: Arrival timestamp. Defaults to ``time.time()``
                when ``None``.
            lora_request: LoRA request to use, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt adapter request to use, if any.
            priority: Request priority. Non-zero values require the
                "priority" scheduling policy.
            inputs: Deprecated alias for ``prompt``.

        Returns:
            An async generator over the request's outputs.

        Raises:
            AsyncEngineDeadError: if the loop is not running and may not be
                started here, or has already errored.
            ValueError: if ``priority`` is non-zero without priority
                scheduling enabled.
        """
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise self.dead_error

        if (priority != 0
                and self.engine.scheduler_config.policy != "priority"):
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        # `arrival_time or time.time()` would silently replace a legitimate
        # 0.0 timestamp, so only default it when it is actually missing.
        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            prompt=prompt,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

        return stream.generator()
952

953
    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Stream the outputs of a text-generation request.

        This coroutine registers the request through :meth:`add_request`
        and yields every :class:`RequestOutput` produced for it, validating
        the type of each output.

        Args:
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use
                for generation, if any.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming HTTP request in the server
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        output_stream = await self.add_request(
            request_id,
            prompt,
            sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
        async for item in output_stream:
            yield LLMEngine.validate_output(item, RequestOutput)
1038
1039
1040

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Stream the outputs of a request to an embedding model.

        This coroutine registers the request through :meth:`add_request`
        and yields every :class:`EmbeddingRequestOutput` produced for it,
        validating the type of each output.

        Args:
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Yields:
            The output `EmbeddingRequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming HTTP request in the server
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        output_stream = await self.add_request(
            request_id,
            prompt,
            pooling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=priority,
        )
        async for item in output_stream:
            yield LLMEngine.validate_output(item, EmbeddingRequestOutput)
1118

Antoni Baum's avatar
Antoni Baum committed
1119
1120
    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: if the background loop is not running.
        """
        if not self.is_running:
            # Same error (and message) as everywhere else a dead loop is
            # reported; built by the `dead_error` property.
            raise self.dead_error

        self._abort(request_id)
1136

Antoni Baum's avatar
Antoni Baum committed
1137
    def _abort(self, request_id: str) -> None:
        """Synchronously abort ``request_id`` through the request tracker.

        ``asyncio.CancelledError`` is passed as the exception for the
        request. If the request is finished or not found, this is a no-op.

        Args:
            request_id: The unique id of the request.
        """
        tracker = self._request_tracker
        tracker.abort_request(request_id,
                              exception=asyncio.CancelledError,
                              verbose=self.log_requests)
1149

1150
1151
    async def get_model_config(self) -> ModelConfig:
        """Return the wrapped vLLM engine's model configuration."""
        engine = self.engine
        return engine.get_model_config()
1153

1154
1155
    async def get_parallel_config(self) -> ParallelConfig:
        """Return the wrapped vLLM engine's parallel configuration."""
        engine = self.engine
        return engine.get_parallel_config()
1157

1158
1159
    async def get_decoding_config(self) -> DecodingConfig:
        """Return the wrapped vLLM engine's decoding configuration."""
        engine = self.engine
        return engine.get_decoding_config()
1161

1162
1163
    async def get_scheduler_config(self) -> SchedulerConfig:
        """Return the wrapped vLLM engine's scheduler configuration."""
        engine = self.engine
        return engine.get_scheduler_config()
1165
1166
1167

    async def get_lora_config(self) -> LoRAConfig:
        """Return the wrapped vLLM engine's LoRA configuration."""
        engine = self.engine
        return engine.get_lora_config()
1169

1170
1171
1172
1173
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the wrapped engine to log its statistics.

        ``scheduler_outputs`` and ``model_output`` are accepted for interface
        compatibility only; they are not used.
        """
        engine = self.engine
        engine.do_log_stats()
1175

1176
    async def check_health(self) -> None:
        """Raise if the engine is unhealthy.

        Raises ``AsyncEngineDeadError`` when :attr:`is_stopped`. Otherwise
        it delegates to the wrapped engine's ``check_health_async``, whose
        errors propagate. The elapsed time is logged at debug level.
        """
        started = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        await self.engine.check_health_async()
        elapsed = time.perf_counter() - started
        logger.debug("Health check took %fs", elapsed)
1185
1186

    async def is_tracing_enabled(self) -> bool:
        """Report whether the wrapped engine has tracing enabled."""
        engine = self.engine
        return engine.is_tracing_enabled()
1188
1189

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register stat ``logger`` on the wrapped engine as ``logger_name``."""
        engine = self.engine
        engine.add_logger(logger_name=logger_name, logger=logger)
1191
1192

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger named ``logger_name`` from the engine."""
        engine = self.engine
        engine.remove_logger(logger_name=logger_name)
1194
1195

    async def start_profile(self) -> None:
        """Start profiling on the executor or on its workers.

        An executor whose type is exactly ``GPUExecutorAsync`` is profiled
        directly. Any other executor, including subclasses, runs
        ``start_profile`` on its workers via ``_run_workers``.
        """
        executor = self.engine.model_executor
        # Exact type identity on purpose: subclasses must not match here.
        if type(executor) is GPUExecutorAsync:
            executor.start_profile()
        else:
            executor._run_workers("start_profile")
1202
1203

    async def stop_profile(self) -> None:
        """Stop profiling on the executor or on its workers.

        An executor whose type is exactly ``GPUExecutorAsync`` is handled
        directly. Any other executor, including subclasses, runs
        ``stop_profile`` on its workers via ``_run_workers``.
        """
        executor = self.engine.model_executor
        # Exact type identity on purpose: subclasses must not match here.
        if type(executor) is GPUExecutorAsync:
            executor.stop_profile()
        else:
            executor._run_workers("stop_profile")