async_llm_engine.py 42 KB
Newer Older
1
2
import asyncio
import time
Antoni Baum's avatar
Antoni Baum committed
3
from functools import partial
4
5
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                    Mapping, Optional, Set, Tuple, Type, Union)
6

7
import vllm.envs as envs
8
9
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
10
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
11
from vllm.engine.arg_utils import AsyncEngineArgs
12
from vllm.engine.async_timeout import asyncio_timeout
13
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
14
from vllm.engine.metrics_types import StatLoggerBase
15
from vllm.executor.executor_base import ExecutorAsyncBase
16
from vllm.executor.ray_utils import initialize_ray_cluster
17
from vllm.inputs import PromptInputs
Woosuk Kwon's avatar
Woosuk Kwon committed
18
from vllm.logger import init_logger
19
from vllm.lora.request import LoRARequest
20
from vllm.model_executor.layers.sampler import SamplerOutput
21
22
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
23
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
24
from vllm.sampling_params import SamplingParams
25
from vllm.sequence import ExecuteModelRequest
26
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
27
from vllm.usage.usage_lib import UsageContext
28
29

logger = init_logger(__name__)
30
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
31

Antoni Baum's avatar
Antoni Baum committed
32

33
34
35
36
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine's background loop has failed or is unusable."""


37
38
39
40
41
42
43
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback intended only for the `engine.run_engine_loop()` task.

    That task runs a `while True` loop, so it can only finish by being
    cancelled or by raising. Cancellation is treated as a graceful shutdown
    and merely logged. Any other outcome -- including an unexpected normal
    return, which is turned into an ``AssertionError`` here -- is logged,
    reported through ``error_callback`` and re-raised.

    Args:
        task: The finished background-loop task.
        error_callback: Called with the exception that ended the task.

    Raises:
        AsyncEngineDeadError: If the task raised (or returned at all)
            instead of being cancelled.
    """
    try:
        return_value = task.result()
        # Raised inside the ``try`` on purpose, so that the generic handler
        # below logs it and notifies ``error_callback``.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {return_value}")
    except asyncio.exceptions.CancelledError:
        # We assume that if the task is cancelled, we are gracefully shutting
        # down. This should only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from e
63
64


65
66
67
# Sentinel queued by AsyncStream.finish() to mark a normal end of stream.
# It is a unique instance, so consumers must recognise it by identity.
STOP_ITERATION: Exception = Exception()


Antoni Baum's avatar
Antoni Baum committed
68
class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator."""

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        # Called with ``request_id`` if the consumer closes the generator
        # before the stream has ended.
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Enqueue ``item`` for the consumer; ignored once finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Mark the stream as finished. Calls after the first are no-ops.

        If ``exception`` is an exception instance or class, the generator
        raises it after yielding everything queued before it; otherwise the
        generator ends normally.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has been called."""
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Yield queued outputs until the stream is finished.

        Raises:
            The exception given to :meth:`finish`, if any.
            asyncio.CancelledError: If the consumer closes the generator
                while it is suspended; ``cancel(request_id)`` is called
                first.
        """
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    # STOP_ITERATION is a sentinel instance: compare by
                    # identity rather than equality.
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """Return True if ``value`` is an exception instance or class."""
        return isinstance(value, BaseException) or (
            isinstance(value, type) and issubclass(value, BaseException))

Antoni Baum's avatar
Antoni Baum committed
117

118
119
120
121
122
class RequestTracker:
    """Synchronous abstraction for tracking requests."""

    def __init__(self) -> None:
        # Streams of requests that have been handed to the engine.
        self._request_streams: Dict[str, AsyncStream] = {}
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # Requests added by callers but not yet picked up by the engine loop.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # IDs of the requests currently waiting in ``_new_requests``. Used to
        # reject duplicate IDs before the engine loop has drained the queue.
        self._pending_request_ids: Set[str] = set()
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None).

        With no ``request_id``, the streams of requests still waiting in the
        new-request queue are finished with ``exc`` as well. Their consumers
        are then notified instead of waiting on an engine loop that may never
        process them.
        """
        if request_id is not None:
            self.abort_request(request_id, exception=exc)
            return

        # NB: tuple() used here because self.abort_request pops the stream
        # out of self._request_streams, so we can't iterate on it directly
        for rid in tuple(self._request_streams.keys()):
            self.abort_request(rid, exception=exc)

        while not self._new_requests.empty():
            stream, _ = self._new_requests.get_nowait()
            stream.finish(exception=exc)
        self._pending_request_ids.clear()

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        Pushes the output to the request's stream. If the output is final,
        the stream is also finished and forgotten.
        """
        request_id = request_output.request_id
        finished = request_output.finished

        if finished:
            stream = self._request_streams.pop(request_id, None)
        else:
            stream = self._request_streams.get(request_id)
        # Guard against a KeyError which can occur if the request was aborted
        # while the output was generated
        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if verbose and finished:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Returns:
            The stream on which the request's outputs will be delivered.

        Raises:
            KeyError: If a request with ``request_id`` is already active, or
                is still waiting to be picked up by the engine loop.
        """
        if (request_id in self._request_streams
                or request_id in self._pending_request_ids):
            raise KeyError(f"Request {request_id} already exists.")

        abort_request = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, abort_request)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))
        self._pending_request_ids.add(request_id)

        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Abort a request during next background loop iteration."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)
        # A pending request that is aborted gets cancelled when the queue is
        # drained, so its ID may be reused immediately.
        self._pending_request_ids.discard(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        Requests that were aborted before the engine saw them are not
        forwarded. Their streams are finished with ``CancelledError``, and
        their IDs are left out of the returned abort set.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._aborted_requests.empty():
            finished_requests.add(self._aborted_requests.get_nowait())

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            request_id = stream.request_id
            self._pending_request_ids.discard(request_id)
            if request_id in finished_requests:
                # The request has already been aborted.
                stream.finish(asyncio.CancelledError)
                finished_requests.discard(request_id)
            else:
                self._request_streams[request_id] = stream
                new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Block until at least one new request is queued, then reset the
        notification event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        return not self._new_requests.empty()
250

Antoni Baum's avatar
Antoni Baum committed
251
252
253
254

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.

        When multi-step scheduling still has remaining steps for the current
        batch, this returns early with the outputs collected so far.
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):

            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            # Nothing scheduled: flush any pending async outputs instead.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # Clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True)

            if outputs and allow_async_output_proc:
                assert len(
                    outputs
                ) == 1, "Async postprocessor expects only a single output set"
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)

        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def add_request_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        """Async version of :meth:`add_request`.

        Raises:
            ValueError: If ``lora_request`` is given but LoRA is not enabled.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if arrival_time is None:
            arrival_time = time.time()

        preprocessed_inputs = await self.input_preprocessor.preprocess_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
        )

    async def check_health_async(self) -> None:
        """Run the tokenizer (if any) and model executor health checks."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
440

441

442
class AsyncLLMEngine:
    """An asynchronous wrapper for :class:`LLMEngine`.

    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
        log_requests: Whether to log the requests.
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
    """
461

Antoni Baum's avatar
Antoni Baum committed
462
463
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

464
465
466
467
    def __init__(self,
                 worker_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        self.worker_use_ray = worker_use_ray
        self.log_requests = log_requests
        self.engine = self._engine_class(*args, **kwargs)

        # This ensures quick processing of request outputs
        # so the append to asyncio queues is not delayed,
        # especially for multi-step.
        self.use_process_request_outputs_callback = True
        if self.use_process_request_outputs_callback:
            self.engine.process_request_outputs_callback = \
                self.process_request_outputs

        # Shielded view of the background task; see start_background_loop().
        self.background_loop: Optional[asyncio.Future] = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        # Set by set_errored(); once set, the loop cannot be restarted.
        self._errored_with: Optional[BaseException] = None

        # Lazy initialized fields (assigned in start_background_loop())
        self._request_tracker: RequestTracker

494
    @classmethod
    def _get_executor_cls(
            cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
        """Pick the async executor class for ``engine_config``.

        An explicit executor class in
        ``parallel_config.distributed_executor_backend`` takes precedence.
        Otherwise the choice depends on the device type and on the backend
        name. A Ray cluster is initialized only for Ray-based executors.

        Raises:
            TypeError: If an explicit executor class does not subclass
                ``ExecutorAsyncBase``.
            RuntimeError: If the requested backend is not supported on XPU.
        """
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)

        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
            if distributed_executor_backend.uses_ray:  # type: ignore
                initialize_ray_cluster(engine_config.parallel_config)
            executor_class = distributed_executor_backend
        elif engine_config.device_config.device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            executor_class = NeuronExecutorAsync
        elif engine_config.device_config.device_type == "tpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                executor_class = RayTPUExecutorAsync
            else:
                assert distributed_executor_backend is None
                from vllm.executor.tpu_executor import TPUExecutorAsync
                executor_class = TPUExecutorAsync
        elif engine_config.device_config.device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutorAsync
            executor_class = CPUExecutorAsync
        elif engine_config.device_config.device_type == "openvino":
            assert distributed_executor_backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            executor_class = OpenVINOExecutorAsync
        elif engine_config.device_config.device_type == "xpu":
            if distributed_executor_backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                executor_class = XPUExecutorAsync
            elif distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                executor_class = RayXPUExecutorAsync
            elif distributed_executor_backend == "mp":
                # Multiprocessing backend: like the GPU "mp" branch below,
                # no Ray cluster is initialized for it.
                from vllm.executor.multiproc_xpu_executor import (
                    MultiprocessingXPUExecutorAsync)
                executor_class = MultiprocessingXPUExecutorAsync
            else:
                raise RuntimeError(
                    "Not supported distributed execution model on XPU device.")
        elif distributed_executor_backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            executor_class = RayGPUExecutorAsync
        elif distributed_executor_backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            executor_class = MultiprocessingGPUExecutorAsync
        else:
            from vllm.executor.gpu_executor import GPUExecutorAsync
            executor_class = GPUExecutorAsync
        return executor_class

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        Args:
            engine_args: Parsed engine arguments; also control request and
                stats logging via ``disable_log_requests`` /
                ``disable_log_stats``.
            start_engine_loop: Forwarded to the constructor.
            usage_context: Forwarded to the underlying engine.
            stat_loggers: Optional stat loggers forwarded to the engine.
        """
        # Create the engine configs.
        engine_config = engine_args.create_engine_config()

        executor_class = cls._get_executor_cls(engine_config)

        # Create the async LLM engine.
        engine = cls(
            executor_class.uses_ray,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )
        return engine

584
585
    @property
    def is_running(self) -> bool:
        """True while the background loop task exists and has not finished."""
        return (self.background_loop is not None
                and self._background_loop_unshielded is not None
                and not self._background_loop_unshielded.done())

    @property
    def is_stopped(self) -> bool:
        """True if the engine has errored or its background task finished."""
        return self.errored or (self.background_loop is not None and
                                self._background_loop_unshielded is not None
                                and self._background_loop_unshielded.done())

    @property
    def errored(self) -> bool:
        """True once an error has been recorded via :meth:`set_errored`."""
        failure = self._errored_with
        return failure is not None

600
601
602
603
604
    @property
    def limit_concurrency(self) -> Optional[int]:
        """Maximum number of concurrently running requests.

        This engine imposes no such limit, so the value is always ``None``.
        """
        no_limit: Optional[int] = None
        return no_limit

605
606
607
608
609
610
    def set_errored(self, exc: Exception) -> None:
        """Record ``exc``; afterwards :attr:`errored` is True and the
        background loop can no longer be started."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine errored and pass ``exc`` to all tracked requests."""
        self.set_errored(exc)
        tracker = self._request_tracker
        tracker.propagate_exception(exc)
611

612
613
614
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer for ``lora_request`` (or the base tokenizer
        when it is None) from the engine's tokenizer group."""
        return await (self.engine.get_tokenizer_group().
                      get_lora_tokenizer_async(lora_request))
618

619
    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the background loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        self._background_loop_unshielded = asyncio.get_event_loop(
        ).create_task(self.run_engine_loop())
        # When the task ends, log the outcome; on failure, mark the engine
        # errored and propagate the exception to tracked requests.
        self._background_loop_unshielded.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        # Shield so that cancelling something awaiting ``background_loop``
        # does not cancel the underlying task.
        self.background_loop = asyncio.shield(self._background_loop_unshielded)
Antoni Baum's avatar
Antoni Baum committed
634

635
636
637
638
639
640
641
642
643
644
645
646
647
648
    def shutdown_background_loop(self) -> None:
        """
        Cancel the background loop task, if any, and drop the references
        to it.

        Call this during cleanup: the task (running ``run_engine_loop``)
        holds references to ``self``, so dropping it lets the async LLM
        engine and the resources it owns (e.g. the executors) be garbage
        collected.
        """
        task = self._background_loop_unshielded
        if task is not None:
            task.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

649
    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        New requests are handed to the engine, aborted ones are removed from
        it, and one engine step is run for ``virtual_engine``. Requests that
        fail validation with ``ValueError`` are failed on their own stream.

        Returns True if there are in-progress requests."""

        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())

        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            try:
                await self.engine.add_request_async(**new_request)
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    new_request["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if aborted_requests:
            await self._engine_abort(aborted_requests)

        request_outputs = await self.engine.step_async(virtual_engine)

        # Put the outputs into the corresponding streams.
        # If used as a callback, then already invoked inside
        # LLMEngine's _process_model_outputs
        if not self.use_process_request_outputs_callback:
            all_finished = self.process_request_outputs(request_outputs)
        else:
            # For callback case, we only need to detect when all
            # requests are finished
            all_finished = all(request_output.finished
                               for request_output in request_outputs)

        return not all_finished

    def process_request_outputs(self, request_outputs) -> bool:
        """Push each request output into its corresponding stream.

        Every output is handed to the request tracker, whether or not it
        is finished.

        Returns:
            True if every output in ``request_outputs`` is finished
            (vacuously True for an empty batch), False otherwise.
        """
        unfinished_seen = False
        for output in request_outputs:
            self._request_tracker.process_request_output(
                output, verbose=self.log_requests)
            if not output.finished:
                unfinished_seen = True
        return not unfinished_seen
696

Antoni Baum's avatar
Antoni Baum committed
697
    async def _engine_abort(self, request_ids: Iterable[str]):
698
        self.engine.abort_request(request_ids)
Antoni Baum's avatar
Antoni Baum committed
699
700

    async def run_engine_loop(self):
        """Drive the engine with one ``engine_step`` task per virtual engine.

        There is one virtual engine per pipeline-parallel stage. When none
        of them has work, the remote worker execution loop is stopped and
        the loop blocks until the request tracker reports new requests.

        The loop never returns normally. It exits only by raising, for
        example when a wait exceeds ``ENGINE_ITERATION_TIMEOUT_S`` (the
        engine is then marked as errored and the timeout is re-raised) or
        when a step task fails, or by being cancelled.
        """
        num_virtual_engines = (
            self.engine.parallel_config.pipeline_parallel_size)
        # busy[ve] is True while virtual engine ``ve`` has a live step task.
        busy = [False] * num_virtual_engines

        while True:
            if not any(busy):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                await self.engine.stop_remote_worker_execution_loop_async()
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")
                step_tasks = [
                    asyncio.create_task(self.engine_step(ve))
                    for ve in range(num_virtual_engines)
                ]
                busy = [True] * num_virtual_engines

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    completed, _ = await asyncio.wait(
                        step_tasks, return_when=asyncio.FIRST_COMPLETED)
                    # Yield to the event loop once per virtual engine before
                    # inspecting the completed tasks.
                    for _ in range(num_virtual_engines):
                        await asyncio.sleep(0)
                for finished_task in completed:
                    wants_more = finished_task.result()
                    ve = step_tasks.index(finished_task)
                    has_pending = (
                        self.engine.has_unfinished_requests_for_virtual_engine(
                            ve))
                    # Start another step for this virtual engine while it
                    # reports in-progress work or still has requests queued.
                    busy[ve] = bool(wants_more or has_pending)
                    if busy[ve]:
                        step_tasks[ve] = asyncio.create_task(
                            self.engine_step(ve))
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise

            await asyncio.sleep(0)

751
752
    # This method does not need to be async, but kept that way
    # for backwards compatibility.
Antoni Baum's avatar
Antoni Baum committed
753
754
755
    async def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Register a request with the tracker and return its output stream.

        The request is only recorded in the request tracker here. The
        background loop forwards it to the underlying engine on a later
        iteration. If the loop is not running and ``start_engine_loop`` is
        enabled, the loop is started first.

        Args:
            request_id: The unique id of the request.
            inputs: The inputs to the LLM.
            params: Sampling parameters (generation) or pooling parameters
                (embedding) for the request.
            arrival_time: Arrival timestamp of the request. When omitted
                (None), ``time.time()`` at the moment of the call is used.
            lora_request: LoRA request to use, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt adapter request to use, if any.

        Returns:
            An async generator yielding the outputs for this request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                ``start_engine_loop`` is disabled.
        """
        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        # Only fall back to "now" when no timestamp was supplied. A falsy
        # but explicit value such as 0.0 must not be silently replaced.
        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            inputs=inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request)

        return stream.generator()
784

785
    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        This coroutine adds the request to the LLMEngine's waiting queue
        and streams the engine's outputs for it back to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Yields:
            The `RequestOutput` objects produced by the LLMEngine for this
            request, each checked to be a `RequestOutput`.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming client request
            >>>     # (see entrypoints/api_server.py).
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        output_stream = await self.add_request(
            request_id,
            inputs,
            sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
        )
        async for raw_output in output_stream:
            yield LLMEngine.validate_output(raw_output, RequestOutput)
867
868
869

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model.

        This coroutine adds the request to the LLMEngine's waiting queue
        and streams the engine's outputs for it back to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.

        Yields:
            The `EmbeddingRequestOutput` objects produced by the LLMEngine
            for this request, each checked to be an
            `EmbeddingRequestOutput`.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming client request
            >>>     # (see entrypoints/api_server.py).
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        output_stream = await self.add_request(
            request_id,
            inputs,
            pooling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )
        async for raw_output in output_stream:
            yield LLMEngine.validate_output(raw_output,
                                            EmbeddingRequestOutput)
944

Antoni Baum's avatar
Antoni Baum committed
945
946
    async def abort(self, request_id: str) -> None:
        """Abort a submitted request.

        Aborting a request that has already finished, or whose id is not
        known, is a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if self.is_running:
            return self._abort(request_id)

        raise AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")
962

Antoni Baum's avatar
Antoni Baum committed
963
    def _abort(self, request_id: str) -> None:
964
965
966
967
968
969
970
971
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
972
        self._request_tracker.abort_request(request_id,
973
                                            exception=asyncio.CancelledError,
974
                                            verbose=self.log_requests)
975

976
977
    async def get_model_config(self) -> ModelConfig:
        """Return the model configuration of the wrapped vLLM engine."""
        engine = self.engine
        return engine.get_model_config()
979

980
981
    async def get_parallel_config(self) -> ParallelConfig:
        """Return the parallel configuration of the wrapped vLLM engine."""
        engine = self.engine
        return engine.get_parallel_config()
983

984
985
    async def get_decoding_config(self) -> DecodingConfig:
        """Return the decoding configuration of the wrapped vLLM engine."""
        engine = self.engine
        return engine.get_decoding_config()
987

988
989
    async def get_scheduler_config(self) -> SchedulerConfig:
        """Return the scheduler configuration of the wrapped vLLM engine."""
        engine = self.engine
        return engine.get_scheduler_config()
991
992
993

    async def get_lora_config(self) -> LoRAConfig:
        """Return the LoRA configuration of the wrapped vLLM engine."""
        engine = self.engine
        return engine.get_lora_config()
995

996
997
998
999
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the wrapped engine to log its statistics.

        Args:
            scheduler_outputs: Optional scheduler outputs; forwarded to the
                wrapped engine's ``do_log_stats``.
            model_output: Optional sampler outputs; forwarded to the wrapped
                engine's ``do_log_stats``.
        """
        # Forward both arguments so callers that supply them are honoured
        # rather than having them silently dropped.
        # NOTE(review): assumes LLMEngine.do_log_stats accepts
        # (scheduler_outputs, model_output) as its first two positional
        # parameters, both defaulting to None -- confirm against LLMEngine.
        self.engine.do_log_stats(scheduler_outputs, model_output)
1001

1002
    async def check_health(self) -> None:
        """Raise an error if the engine is unhealthy.

        Raises:
            AsyncEngineDeadError: If the background loop is stopped.
            Any exception raised by the wrapped engine's
            ``check_health_async``.
        """
        started_at = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        await self.engine.check_health_async()
        elapsed = time.perf_counter() - started_at
        logger.debug("Health check took %fs", elapsed)
1011
1012

    async def is_tracing_enabled(self) -> bool:
        """Report whether tracing is enabled on the wrapped engine."""
        tracing_on = self.engine.is_tracing_enabled()
        return tracing_on
1014
1015

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register ``logger`` with the wrapped engine as ``logger_name``."""
        engine = self.engine
        engine.add_logger(logger_name=logger_name, logger=logger)
1017
1018

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger named ``logger_name`` from the engine."""
        engine = self.engine
        engine.remove_logger(logger_name=logger_name)
1020
1021
1022
1023
1024
1025

    async def start_profile(self) -> None:
        """Start profiling on the model executor's worker(s).

        Executors that provide ``_run_workers`` broadcast ``start_profile``
        to all of their workers, as before. Executors that lack it would
        otherwise fail here with ``AttributeError``; for those the call is
        made on the executor's ``driver_worker`` instead.
        """
        executor = self.engine.model_executor
        if hasattr(executor, "_run_workers"):
            executor._run_workers("start_profile")
        else:
            # NOTE(review): assumes executors without ``_run_workers`` keep
            # their single worker in ``driver_worker`` and that it exposes
            # ``start_profile`` -- confirm for every executor type that can
            # reach this path.
            executor.driver_worker.start_profile()

    async def stop_profile(self) -> None:
        """Stop profiling on the model executor's worker(s).

        Executors that provide ``_run_workers`` broadcast ``stop_profile``
        to all of their workers, as before. Executors that lack it would
        otherwise fail here with ``AttributeError``; for those the call is
        made on the executor's ``driver_worker`` instead.
        """
        executor = self.engine.model_executor
        if hasattr(executor, "_run_workers"):
            executor._run_workers("stop_profile")
        else:
            # NOTE(review): assumes executors without ``_run_workers`` keep
            # their single worker in ``driver_worker`` and that it exposes
            # ``stop_profile`` -- confirm for every executor type that can
            # reach this path.
            executor.driver_worker.stop_profile()