"""Asynchronous front-end for vLLM's :class:`LLMEngine`.

Provides :class:`AsyncLLMEngine`, which drives an engine from an asyncio
background loop, plus the per-request stream and request-tracking helpers it
uses.
"""
import asyncio
import time
from functools import partial
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                    Mapping, Optional, Set, Tuple, Type, Union)

import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext

logger = init_logger(__name__)
# Read once at import time from the vLLM environment settings.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S


class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine's background loop has died or already errored."""


def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback for the `engine.run_engine_loop()` background task.

    That task runs a `while True` loop that can only exit if there is an
    exception, so every completion other than cancellation is treated as a
    fatal engine failure.

    Args:
        task: The finished background task.
        error_callback: Invoked with the failure before it is re-raised.

    Raises:
        AsyncEngineDeadError: If the task finished for any reason other than
            cancellation, including returning normally.
    """
    try:
        return_value = task.result()
        # A normal return should be impossible; raise so the generic handler
        # below reports it like any other failure.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {return_value}")
    except asyncio.exceptions.CancelledError:
        # We assume that if the task is cancelled, we are gracefully shutting
        # down. This should only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from e


# Unique sentinel queued into an AsyncStream to signal normal end-of-stream;
# it is an Exception instance so the stream's "raisable" check catches it.
STOP_ITERATION = Exception()  # Sentinel


class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator."""

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        # Invoked with request_id if the consumer closes the generator early.
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Enqueue ``item`` for the consumer; ignored once finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Mark the stream finished and wake the consumer.

        If ``exception`` is an exception instance or class, the generator
        raises it; otherwise iteration ends normally. Repeated calls are
        no-ops.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Yield queued outputs until the stream is finished.

        Raises:
            The exception (instance or class) passed to :meth:`finish`, if any.
            asyncio.CancelledError: When the generator is closed before
                completion (``GeneratorExit``); ``cancel(request_id)`` is
                called first.
        """
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    # STOP_ITERATION is a unique sentinel: compare by identity.
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """Return True if ``value`` is an exception instance or class."""
        return isinstance(value, BaseException) or \
                (isinstance(value, type) and \
                 issubclass(value, BaseException))


class RequestTracker:
    """Synchronous abstraction for tracking requests.

    Callers queue new requests and aborts here; the engine's background loop
    drains them with :meth:`get_new_and_aborted_requests` and routes engine
    outputs back to the per-request :class:`AsyncStream`.
    """

    def __init__(self) -> None:
        # Streams of requests already handed to the engine, keyed by id.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Ids aborted since the last drain by the background loop.
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, engine add_request kwargs) pairs waiting for the next
        # background-loop iteration.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set by add_request(); awaited by wait_for_new_requests().
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None)."""
        if request_id is not None:
            self.abort_request(request_id, exception=exc)
        else:
            # NB: tuple() used here because self.abort_request pops the stream
            # out of self._request_streams, so we can't iterate on it directly
            for rid in tuple(self._request_streams.keys()):
                self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        The output is forwarded to its stream; a finished output also
        removes the stream from tracking and closes it.
        """
        request_id = request_output.request_id
        finished = request_output.finished

        if finished:
            stream = self._request_streams.pop(request_id, None)
        else:
            stream = self._request_streams.get(request_id)
        # Guard against a KeyError which can occur if the request was aborted
        # while the output was generated
        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if verbose and finished:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Raises:
            KeyError: If a stream for ``request_id`` is already tracked.
        """
        # NOTE(review): only ids already registered in _request_streams are
        # checked; a duplicate id still waiting in _new_requests is not
        # detected here.
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        # Closing the stream's generator early aborts the request.
        abort_request = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, abort_request)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Abort a request during next background loop iteration.

        The id is queued for the engine-side abort; a tracked stream is
        removed immediately and finished with ``exception`` (if given).
        """
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        Returns:
            A tuple of (kwargs dicts for requests to add to the engine,
            ids of requests to abort in the engine).
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._aborted_requests.empty():
            request_id = self._aborted_requests.get_nowait()
            finished_requests.add(request_id)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            request_id = stream.request_id
            if request_id in finished_requests:
                # The request has already been aborted. It never reached the
                # engine, so cancel its stream and skip the engine-side abort.
                stream.finish(asyncio.CancelledError)
                finished_requests.discard(request_id)
            else:
                self._request_streams[request_id] = stream
                new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Wait until a new request is queued (returns immediately if one
        already is), then clear ``new_requests_event``."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Return True if requests are queued but not yet drained."""
        return not self._new_requests.empty()


class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):

            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            # Nothing scheduled: flush any pending async outputs instead.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # Clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True)

            if outputs and allow_async_output_proc:
                assert len(
                    outputs
                ) == 1, "Async postprocessor expects only a single output set"
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)

        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def add_request_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        """Async version of :meth:`add_request`.

        ``arrival_time`` defaults to ``time.time()`` when not given.

        Raises:
            ValueError: If ``lora_request`` is given but LoRA is not enabled.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if arrival_time is None:
            arrival_time = time.time()

        preprocessed_inputs = await self.input_preprocessor.preprocess_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
        )

    async def check_health_async(self) -> None:
        """Check the tokenizer (if one is configured) and model executor."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()


class AsyncLLMEngine:
    """An asynchronous wrapper for :class:`LLMEngine`.

    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
        log_requests: Whether to log the requests.
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
    """

    # Engine class instantiated by __init__; it is looked up through ``self``,
    # so a subclass can substitute its own implementation.
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

    def __init__(self,
                 worker_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """See the class docstring for the meaning of the arguments."""
        self.worker_use_ray = worker_use_ray
        self.log_requests = log_requests
        # Remaining positional/keyword arguments go to the engine class.
        self.engine = self._engine_class(*args, **kwargs)

        # This ensures quick processing of request outputs
        # so the append to asyncio queues is not delayed,
        # especially for multi-step.
        self.use_process_request_outputs_callback = True
        if self.use_process_request_outputs_callback:
            self.engine.process_request_outputs_callback = \
                self.process_request_outputs

        self.background_loop: Optional[asyncio.Future] = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Lazy initialized fields (created in start_background_loop)
        self._request_tracker: RequestTracker

    @classmethod
    def _get_executor_cls(
            cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
        """Select the async executor class for ``engine_config``.

        An executor class given directly as
        ``parallel_config.distributed_executor_backend`` takes precedence.
        Otherwise the choice depends on the device type and on the backend
        name (``"ray"``, ``"mp"`` or ``None``). The Ray cluster is initialized
        here only for Ray-based executors.

        Raises:
            TypeError: If a backend class does not subclass ExecutorAsyncBase.
            RuntimeError: If the XPU backend name is not supported.
        """
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
        device_type = engine_config.device_config.device_type

        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
            if distributed_executor_backend.uses_ray:  # type: ignore
                initialize_ray_cluster(engine_config.parallel_config)
            executor_class = distributed_executor_backend
        elif device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            executor_class = NeuronExecutorAsync
        elif device_type == "tpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                executor_class = RayTPUExecutorAsync
            else:
                assert distributed_executor_backend is None
                from vllm.executor.tpu_executor import TPUExecutorAsync
                executor_class = TPUExecutorAsync
        elif device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutorAsync
            executor_class = CPUExecutorAsync
        elif device_type == "openvino":
            assert distributed_executor_backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            executor_class = OpenVINOExecutorAsync
        elif device_type == "xpu":
            if distributed_executor_backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                executor_class = XPUExecutorAsync
            elif distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                executor_class = RayXPUExecutorAsync
            elif distributed_executor_backend == "mp":
                # Multiprocessing does not use Ray, so (as in the GPU "mp"
                # branch) no Ray cluster is initialized here.
                from vllm.executor.multiproc_xpu_executor import (
                    MultiprocessingXPUExecutorAsync)
                executor_class = MultiprocessingXPUExecutorAsync
            else:
                raise RuntimeError(
                    "Not supported distributed execution model on XPU device.")
        elif distributed_executor_backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            executor_class = RayGPUExecutorAsync
        elif distributed_executor_backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            executor_class = MultiprocessingGPUExecutorAsync
        else:
            # Single-process GPU executor (imported at module level).
            executor_class = GPUExecutorAsync
        return executor_class

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_config = engine_args.create_engine_config()

        executor_class = cls._get_executor_cls(engine_config)

        # Create the async LLM engine. The executor's ``uses_ray`` flag is
        # passed as ``worker_use_ray``.
        engine = cls(
            executor_class.uses_ray,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )
        return engine

    @property
    def is_running(self) -> bool:
        """True while the background loop task exists and is not done."""
        return (self.background_loop is not None
                and self._background_loop_unshielded is not None
                and not self._background_loop_unshielded.done())

    @property
    def is_stopped(self) -> bool:
        """True if the engine errored or the background loop task is done."""
        return self.errored or (self.background_loop is not None and
                                self._background_loop_unshielded is not None
                                and self._background_loop_unshielded.done())

    @property
    def errored(self) -> bool:
        """True once an error has been recorded via ``set_errored``."""
        return self._errored_with is not None

    @property
    def limit_concurrency(self) -> Optional[int]:
        """Maximum number of concurrently running requests.

        Always ``None`` here, i.e. no limit is reported.
        """
        return None

    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` as the fatal error; ``errored`` becomes True."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine as errored and fail every tracked request stream.

        Passed as ``error_callback`` to the background task's done-callback
        in ``start_background_loop``.
        """
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the engine tokenizer-group's tokenizer for ``lora_request``.
        """
        return await (self.engine.get_tokenizer_group().
                      get_lora_tokenizer_async(lora_request))

    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the background loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        self._background_loop_unshielded = asyncio.get_event_loop(
        ).create_task(self.run_engine_loop())
        self._background_loop_unshielded.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        # Shielded so that cancelling an awaiter of ``background_loop`` does
        # not cancel the engine task itself.
        self.background_loop = asyncio.shield(self._background_loop_unshielded)

    def shutdown_background_loop(self) -> None:
        """
        Shut down the background loop.

        This method needs to be called during cleanup to remove
        references to `self` and properly GC the resources held
        by the async LLM engine (e.g., the executors as well as
        their resources).
        """
        if self._background_loop_unshielded is not None:
            self._background_loop_unshielded.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Adds newly queued requests, forwards aborts, then runs one engine
        step for ``virtual_engine``.

        Returns True if there are in-progress requests (i.e. some output of
        this step is not finished); False when no outputs were produced.
        """
        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())

        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            try:
                await self.engine.add_request_async(**new_request)
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    new_request["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if aborted_requests:
            await self._engine_abort(aborted_requests)

        request_outputs = await self.engine.step_async(virtual_engine)

        # Put the outputs into the corresponding streams.
        # If used as a callback, then already invoked inside
        # LLMEngine's _process_model_outputs
        if not self.use_process_request_outputs_callback:
            all_finished = self.process_request_outputs(request_outputs)
        else:
            # For callback case, we only need to detect when all
            # requests are finished
            all_finished = all(request_output.finished
                               for request_output in request_outputs)

        return not all_finished

    def process_request_outputs(self, request_outputs) -> bool:
        """Deliver every output to its request's stream.

        All outputs are handed to the request tracker, even after an
        unfinished one has been seen.

        Returns:
            True if every output is marked finished (True for an empty
            iterable), False otherwise.
        """
        finished_flags = []
        for output in request_outputs:
            self._request_tracker.process_request_output(
                output, verbose=self.log_requests)
            finished_flags.append(output.finished)
        return all(finished_flags)
697

Antoni Baum's avatar
Antoni Baum committed
698
    async def _engine_abort(self, request_ids: Iterable[str]):
699
        self.engine.abort_request(request_ids)
Antoni Baum's avatar
Antoni Baum committed
700
701

    async def run_engine_loop(self):
        """Drive ``engine_step`` for every virtual engine, forever.

        Virtual engines are numbered ``0 .. pipeline_parallel_size - 1``.
        When none of them has work, the remote workers' execution loop is
        stopped and the loop blocks until the request tracker reports new
        requests. It then launches one ``engine_step`` task per virtual
        engine. A finished step is rescheduled as long as it reported
        in-progress requests or the engine still has unfinished requests
        for that virtual engine.

        Never returns normally.

        Raises:
            asyncio.TimeoutError: If waiting for a step takes longer than
                ``ENGINE_ITERATION_TIMEOUT_S``; the engine is marked as
                errored before the exception is re-raised.
        """
        num_virtual_engines = (
            self.engine.parallel_config.pipeline_parallel_size)
        busy = [False] * num_virtual_engines

        while True:
            if not any(busy):
                logger.debug("Waiting for new requests...")
                # Park the parallel workers' execute-model loop until there
                # is more work. Otherwise they may wait indefinitely (and
                # time out) inside torch.distributed ops. Parking also frees
                # their RPC thread for other queued control-plane messages,
                # e.g. adding/removing LoRA adapters.
                await self.engine.stop_remote_worker_execution_loop_async()
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")
                step_tasks = [
                    asyncio.create_task(self.engine_step(ve))
                    for ve in range(num_virtual_engines)
                ]
                busy = [True] * num_virtual_engines

            # Give up if one iteration hangs because of an unrecoverable
            # failure (e.g. an NCCL timeout).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        step_tasks, return_when=asyncio.FIRST_COMPLETED)
                    # Yield to the event loop once per virtual engine,
                    # presumably so the remaining step tasks can progress.
                    for _ in range(num_virtual_engines):
                        await asyncio.sleep(0)
                # An idle virtual engine keeps its already-finished task in
                # step_tasks. The next asyncio.wait therefore returns it
                # immediately, and its unfinished-request check below is
                # re-evaluated on every pass.
                for finished_task in done:
                    step_result = finished_task.result()
                    ve = step_tasks.index(finished_task)
                    still_pending = (
                        self.engine.has_unfinished_requests_for_virtual_engine(
                            ve))
                    if step_result or still_pending:
                        step_tasks[ve] = asyncio.create_task(
                            self.engine_step(ve))
                        busy[ve] = True
                    else:
                        busy[ve] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise

            await asyncio.sleep(0)

752
753
    # This method does not need to be async, but kept that way
    # for backwards compatibility.
Antoni Baum's avatar
Antoni Baum committed
754
755
756
    async def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Register a request with the tracker and return its output stream.

        If the background loop is not running, it is started on demand when
        ``self.start_engine_loop`` is set.

        Args:
            request_id: The unique id of the request.
            inputs: The inputs to the LLM.
            params: Sampling parameters (generation) or pooling parameters
                (embedding) for the request.
            arrival_time: Arrival timestamp of the request. ``None`` means
                "now" (``time.time()``).
            lora_request: LoRA request to use, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            The async generator of the stream created for this request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                ``self.start_engine_loop`` is not set.
        """
        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        # Only substitute the current time when no timestamp was supplied;
        # `arrival_time or time.time()` would also discard an explicit 0.0.
        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            inputs=inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request)

        return stream.generator()
785

786
    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is an async generator:
        it adds the request into the waiting queue of the LLMEngine and
        streams the outputs from the LLMEngine to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use
                for generation, if any.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request. Each output is passed through
            :meth:`LLMEngine.validate_output` before it is yielded.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                the engine is not configured to start it on demand. This is
                raised on the first iteration of the generator.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in await self.add_request(
                request_id,
                inputs,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
        ):
            yield LLMEngine.validate_output(output, RequestOutput)
868
869
870

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model.

        Generate outputs for a request. This method is an async generator:
        it adds the request into the waiting queue of the LLMEngine and
        streams the outputs from the LLMEngine to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.

        Yields:
            The output `EmbeddingRequestOutput` objects from the LLMEngine
            for the request. Each output is passed through
            :meth:`LLMEngine.validate_output` before it is yielded.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                the engine is not configured to start it on demand. This is
                raised on the first iteration of the generator.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in await self.add_request(
                request_id,
                inputs,
                pooling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
        ):
            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
945

Antoni Baum's avatar
Antoni Baum committed
946
947
    async def abort(self, request_id: str) -> None:
        """Cancel a previously submitted request.

        A request that has already finished, or is not found, is left
        alone (no-op).

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if self.is_running:
            return self._abort(request_id)

        raise AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")
963

Antoni Baum's avatar
Antoni Baum committed
964
    def _abort(self, request_id: str) -> None:
965
966
967
968
969
970
971
972
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
973
        self._request_tracker.abort_request(request_id,
974
                                            exception=asyncio.CancelledError,
975
                                            verbose=self.log_requests)
976

977
978
    async def get_model_config(self) -> ModelConfig:
        """Return the model configuration of the underlying vLLM engine."""
        config = self.engine.get_model_config()
        return config
980

981
982
    async def get_parallel_config(self) -> ParallelConfig:
        """Return the parallel configuration of the underlying vLLM engine."""
        config = self.engine.get_parallel_config()
        return config
984

985
986
    async def get_decoding_config(self) -> DecodingConfig:
        """Return the decoding configuration of the underlying vLLM engine."""
        config = self.engine.get_decoding_config()
        return config
988

989
990
    async def get_scheduler_config(self) -> SchedulerConfig:
        """Return the scheduler configuration of the underlying engine."""
        config = self.engine.get_scheduler_config()
        return config
992
993
994

    async def get_lora_config(self) -> LoRAConfig:
        """Return the LoRA configuration of the underlying vLLM engine."""
        config = self.engine.get_lora_config()
        return config
996

997
998
999
1000
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the underlying engine to log its stats.

        ``scheduler_outputs`` and ``model_output`` are currently ignored;
        the engine's ``do_log_stats`` is called without arguments.
        """
        engine = self.engine
        engine.do_log_stats()
1002

1003
    async def check_health(self) -> None:
        """Raise an error if the engine is unhealthy.

        Any exception from the underlying engine's ``check_health_async``
        propagates to the caller. The duration of the check is logged at
        debug level.

        Raises:
            AsyncEngineDeadError: If the background loop is stopped.
        """
        started_at = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        await self.engine.check_health_async()
        elapsed = time.perf_counter() - started_at
        logger.debug("Health check took %fs", elapsed)
1012
1013

    async def is_tracing_enabled(self) -> bool:
        """Report whether the underlying engine has tracing enabled."""
        tracing_on = self.engine.is_tracing_enabled()
        return tracing_on
1015
1016

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger with the underlying engine.

        Args:
            logger_name: Name under which ``logger`` is registered.
            logger: The stat logger instance to add.
        """
        engine = self.engine
        engine.add_logger(logger_name=logger_name, logger=logger)
1018
1019

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger named ``logger_name`` from the engine."""
        engine = self.engine
        engine.remove_logger(logger_name=logger_name)
1021
1022

    async def start_profile(self) -> None:
        """Start profiling through the engine's model executor.

        If the executor is exactly a ``GPUExecutorAsync``, its own
        ``start_profile`` is called. Any other executor, including
        subclasses of ``GPUExecutorAsync``, is asked to run
        ``start_profile`` via ``_run_workers``.
        """
        executor = self.engine.model_executor
        # Exact type check on purpose, so that inherited classes are not
        # captured. Compare class identity with `is` rather than `==`.
        if type(executor) is GPUExecutorAsync:
            executor.start_profile()
        else:
            executor._run_workers("start_profile")
1029
1030

    async def stop_profile(self) -> None:
        """Stop profiling through the engine's model executor.

        If the executor is exactly a ``GPUExecutorAsync``, its own
        ``stop_profile`` is called. Any other executor, including
        subclasses of ``GPUExecutorAsync``, is asked to run
        ``stop_profile`` via ``_run_workers``.
        """
        executor = self.engine.model_executor
        # Exact type check on purpose, so that inherited classes are not
        # captured. Compare class identity with `is` rather than `==`.
        if type(executor) is GPUExecutorAsync:
            executor.stop_profile()
        else:
            executor._run_workers("stop_profile")