async_llm_engine.py 43 KB
Newer Older
1
2
import asyncio
import time
3
import weakref
Antoni Baum's avatar
Antoni Baum committed
4
from functools import partial
5
6
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                    Mapping, Optional, Set, Tuple, Type, Union)
7
from weakref import ReferenceType
8

9
import vllm.envs as envs
10
11
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
12
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
13
from vllm.engine.arg_utils import AsyncEngineArgs
14
from vllm.engine.async_timeout import asyncio_timeout
15
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
16
from vllm.engine.metrics_types import StatLoggerBase
17
from vllm.executor.executor_base import ExecutorAsyncBase
18
from vllm.executor.gpu_executor import GPUExecutorAsync
19
from vllm.executor.ray_utils import initialize_ray_cluster
20
from vllm.inputs import PromptInputs
Woosuk Kwon's avatar
Woosuk Kwon committed
21
from vllm.logger import init_logger
22
from vllm.lora.request import LoRARequest
23
from vllm.model_executor.layers.sampler import SamplerOutput
24
25
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
26
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
27
from vllm.sampling_params import SamplingParams
28
from vllm.sequence import ExecuteModelRequest
29
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
30
from vllm.usage.usage_lib import UsageContext
31
from vllm.utils import weak_bind
32
33

logger = init_logger(__name__)
34
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
35

Antoni Baum's avatar
Antoni Baum committed
36

37
38
39
40
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine's background loop has died (or already errored)
    and the engine can no longer make progress on requests."""


41
42
43
44
45
46
47
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback meant only for the `engine.run_engine_loop()` task.

    That task runs a `while True` loop, so the only "normal" way for it to end
    is cancellation, which is logged as a graceful shutdown. Any other outcome
    is fatal:

    * If the task returned a value, an AssertionError is raised here. It is
      deliberately caught by the generic handler below, so it is reported the
      same way as a real failure.
    * If the task raised an exception, it is logged and passed to
      ``error_callback``, and ``AsyncEngineDeadError`` is raised, chained
      from the original error.
    """
    try:
        outcome = task.result()
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {outcome}")
    except asyncio.exceptions.CancelledError:
        # Cancellation is treated as the graceful-shutdown path (program exit).
        logger.info("Engine is gracefully shutting down.")
    except Exception as failure:
        logger.error("Engine background task failed", exc_info=failure)
        error_callback(failure)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from failure
67
68


69
70
71
# Unique sentinel that AsyncStream.finish() enqueues when a stream ends without
# an error; AsyncStream.generator() ends iteration when it dequeues it. It is an
# Exception instance so it passes the same "raisable" check as real errors.
STOP_ITERATION = Exception()  # Sentinel


Antoni Baum's avatar
Antoni Baum committed
72
class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator.

    Producers call :meth:`put` for each output and :meth:`finish` exactly
    once at the end. Items put after :meth:`finish` are ignored, as are
    repeated :meth:`finish` calls. The consumer iterates :meth:`generator`.
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        # Called with request_id if the consumer closes the generator early.
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Enqueue an output (or exception) unless the stream has finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Mark the stream finished and enqueue its terminal item.

        Args:
            exception: If raisable (an exception instance or class), the
                consumer's generator raises it. Otherwise the generator
                stops normally via the ``STOP_ITERATION`` sentinel.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has been called."""
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Yield queued outputs until the stream finishes.

        Raises:
            The exception passed to :meth:`finish` (or put into the stream).
            asyncio.CancelledError: If the consumer closes the generator
                early. In that case the ``cancel`` callback is called first
                with this stream's request id.
        """
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    # The sentinel is matched by identity; equality could be
                    # overridden by an arbitrary exception type.
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """True for exception instances and exception classes."""
        return isinstance(value, BaseException) or (
            isinstance(value, type) and issubclass(value, BaseException))

Antoni Baum's avatar
Antoni Baum committed
121

122
123
124
125
126
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    Callers enqueue new requests (:meth:`add_request`) and aborts
    (:meth:`abort_request`). The engine background loop drains both queues
    with :meth:`get_new_and_aborted_requests` and reports results back
    through :meth:`process_request_output` / :meth:`process_exception`.
    """

    def __init__(self) -> None:
        # Streams of requests that the engine loop has already picked up.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Ids whose abort has not yet been handed to the engine.
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, engine add_request kwargs) not yet handed to the engine.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set whenever add_request queues something; awaited by
        # wait_for_new_requests().
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Abort one request (or every tracked request when request_id is
        None), delivering ``exc`` to the affected stream(s)."""
        # Snapshot the ids first: abort_request pops entries from the dict.
        if request_id is None:
            targets = tuple(self._request_streams.keys())
        else:
            targets = (request_id, )
        for rid in targets:
            self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Route an engine output to its stream, closing it when finished."""
        request_id = request_output.request_id
        finished = request_output.finished

        streams = self._request_streams
        stream = (streams.pop(request_id, None)
                  if finished else streams.get(request_id))

        # No stream means the request was aborted while this output was
        # being generated.
        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if verbose and finished:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Fail a single request with an exception raised by the engine."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Queue a request for the engine's next background-loop iteration.

        Returns:
            The stream the caller iterates to receive outputs.

        Raises:
            KeyError: If a stream for ``request_id`` is already registered.
                NOTE(review): requests still waiting in the pending queue
                are not checked here.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        # Closing the stream early aborts the request.
        on_cancel = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, on_cancel)

        engine_kwargs = {"request_id": request_id}
        engine_kwargs.update(engine_add_request_kwargs)
        self._new_requests.put_nowait((stream, engine_kwargs))
        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)
        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Abort a request during next background loop iteration.

        The id is always queued for the engine. If the stream is already
        registered, it is removed and finished with ``exception`` (or ended
        normally when ``exception`` is None).
        """
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Drain both queues for the engine.

        Returns:
            A tuple ``(new_requests, finished_requests)``. ``new_requests``
            holds the add_request kwargs to forward to the engine.
            ``finished_requests`` holds the ids to abort in the engine.
        """
        finished_requests: Set[str] = set()
        while not self._aborted_requests.empty():
            finished_requests.add(self._aborted_requests.get_nowait())

        new_requests: List[Dict] = []
        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            rid = stream.request_id
            if rid not in finished_requests:
                self._request_streams[rid] = stream
                new_requests.append(new_request)
                continue
            # Aborted before the engine ever saw it: cancel the stream
            # locally and drop the id so the engine is not asked to abort it.
            stream.finish(asyncio.CancelledError)
            finished_requests.discard(rid)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Wait until a new request is pending, then reset the event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """True if add_request has queued requests not yet drained."""
        return not self._new_requests.empty()
254

Antoni Baum's avatar
Antoni Baum committed
255
256
257
258

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods.

    Adds coroutine versions of stepping (:meth:`step_async`), request
    submission (:meth:`add_request_async`), health checking and stopping the
    remote worker loop. These await the model executor's and input
    preprocessor's async APIs.
    """

    # NOTE(review): only forwards to LLMEngine.__init__; kept as-is.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.

        Args:
            virtual_engine: Index selecting the per-virtual-engine state used
                by this step (``self.scheduler``, ``self.scheduler_contexts``,
                ``self.cached_scheduler_outputs``, ``self.async_callbacks``).

        Returns:
            ``ctx.request_outputs`` for this virtual engine. The list is
            cleared at the start of every call and is presumably filled by
            ``_process_model_outputs`` (defined elsewhere). When multi-step
            steps remain, it is returned early, skipping ``append_output``
            and the final drain.
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):

            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            # Maybe switch from async mode to sync mode
            # (flush outputs still queued from earlier async iterations).
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)

        # Either restored from the multi-step cache or freshly scheduled above.
        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                # Hand the executor this virtual engine's async callback.
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            # Nothing scheduled: flush any pending outputs, produce none.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # Clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True)

            # Async output processing: presumably advances sequence state with
            # the (single) output set now, leaving post-processing of the
            # queued output for later — confirm in LLMEngine.
            if outputs and allow_async_output_proc:
                assert len(
                    outputs
                ) == 1, "Async postprocessor expects only a single output set"
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            # Synchronous mode: process outputs, log stats and trace now.
            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)

        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def add_request_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        """Async version of :meth:`add_request`.

        Preprocesses ``inputs`` with the async input preprocessor and passes
        the processed request to ``_add_processed_request``.

        Args:
            arrival_time: Defaults to ``time.time()`` when None.

        Raises:
            ValueError: If ``lora_request`` is given but LoRA is not enabled.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if arrival_time is None:
            arrival_time = time.time()

        preprocessed_inputs = await self.input_preprocessor.preprocess_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
        )

    async def check_health_async(self) -> None:
        """Run the tokenizer's health check (when a tokenizer is set) and then
        the model executor's; any error they raise propagates."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
444

445

446
class AsyncLLMEngine:
447
    """An asynchronous wrapper for :class:`LLMEngine`.
448

449
450
451
452
453
    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.
454
455

    Args:
456
        log_requests: Whether to log the requests.
457
458
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
459
460
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
461
    """
462

Antoni Baum's avatar
Antoni Baum committed
463
464
    # Engine implementation instantiated by __init__ with the forwarded
    # *args/**kwargs; being a class attribute, subclasses can override it.
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

465
466
467
    def __init__(self,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        self.log_requests = log_requests
        self.engine = self._engine_class(*args, **kwargs)

        # With async output processing, the engine delivers request outputs
        # through a callback, so the asyncio queues are fed promptly
        # (especially for multi-step) instead of after the whole step.
        use_callback = self.engine.model_config.use_async_output_proc
        self.use_process_request_outputs_callback = use_callback
        if use_callback:
            # weak_bind presumably avoids a strong engine -> self reference.
            self.engine.process_request_outputs_callback = weak_bind(
                self.process_request_outputs)

        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Shielded future handed out to callers.
        self.background_loop: Optional[asyncio.Future] = None
        # The underlying task is referenced too, so it is not garbage
        # collected.
        self._background_loop_unshielded: Optional[asyncio.Task] = None

        # Created lazily by start_background_loop().
        self._request_tracker: RequestTracker

494
495
496
497
498
    def __del__(self):
        """Wake up the engine loop on destruction so it can exit cleanly.

        Setting ``new_requests_event`` releases anything waiting on it
        (e.g. ``RequestTracker.wait_for_new_requests``).
        """
        # The attribute is `_request_tracker` (set in start_background_loop).
        # It is absent if the loop was never started. Compare against None
        # explicitly: RequestTracker defines __len__, so an idle tracker is
        # falsy, and a truthiness test would skip the wake-up.
        rt = getattr(self, "_request_tracker", None)
        if rt is not None:
            rt.new_requests_event.set()

499
    @classmethod
    def _get_executor_cls(
            cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
        """Select the async executor class for the configured device/backend.

        An executor class passed directly as ``distributed_executor_backend``
        wins, and must subclass ExecutorAsyncBase. Otherwise the choice
        depends on the device type and then on the backend name
        ("ray" / "mp" / None). Each executor module is imported inside its
        branch, so only the chosen backend's module is loaded.

        Raises:
            TypeError: For a backend class that is not an ExecutorAsyncBase.
            RuntimeError: For an unsupported backend on XPU.
        """
        backend = engine_config.parallel_config.distributed_executor_backend

        if isinstance(backend, type):
            if issubclass(backend, ExecutorAsyncBase):
                return backend
            raise TypeError(
                "distributed_executor_backend must be a subclass of "
                f"ExecutorAsyncBase. Got {backend}.")

        device_type = engine_config.device_config.device_type

        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            return NeuronExecutorAsync

        if device_type == "tpu":
            if backend == "ray":
                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                return RayTPUExecutorAsync
            assert backend is None
            from vllm.executor.tpu_executor import TPUExecutorAsync
            return TPUExecutorAsync

        if device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutorAsync
            return CPUExecutorAsync

        if device_type == "openvino":
            assert backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            return OpenVINOExecutorAsync

        if device_type == "xpu":
            if backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                return XPUExecutorAsync
            if backend == "ray":
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                return RayXPUExecutorAsync
            if backend == "mp":
                from vllm.executor.multiproc_xpu_executor import (
                    MultiprocessingXPUExecutorAsync)
                return MultiprocessingXPUExecutorAsync
            raise RuntimeError(
                "Not supported distributed execution model on XPU device.")

        # GPU (default) path.
        if backend == "ray":
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            return RayGPUExecutorAsync
        if backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            return MultiprocessingGPUExecutorAsync
        from vllm.executor.gpu_executor import GPUExecutorAsync
        return GPUExecutorAsync

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        engine_config: Optional[EngineConfig] = None,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        Args:
            engine_args: Source of the config and of the logging flags.
            engine_config: A prebuilt config. When None, it is created with
                ``engine_args.create_engine_config()``.
            start_engine_loop: Forwarded to the constructor.
            usage_context: Forwarded to the engine.
            stat_loggers: Forwarded to the engine.
        """
        if engine_config is not None:
            config = engine_config
        else:
            config = engine_args.create_engine_config()

        executor_class = cls._get_executor_cls(config)
        # Ray-based executors need the Ray cluster up before construction.
        if executor_class.uses_ray:
            initialize_ray_cluster(config.parallel_config)

        return cls(
            **config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

587
588
    @property
    def is_running(self) -> bool:
        """True once the background loop was started and its task is not done."""
        task = self._background_loop_unshielded
        if self.background_loop is None or task is None:
            return False
        return not task.done()

    @property
    def is_stopped(self) -> bool:
        """True if the engine has errored, or if its background task was
        started and has since finished."""
        if self.errored:
            return True
        task = self._background_loop_unshielded
        return (self.background_loop is not None and task is not None
                and task.done())

    @property
    def errored(self) -> bool:
        """True once an error has been recorded via :meth:`set_errored`."""
        recorded = self._errored_with
        return recorded is not None

603
604
605
606
607
    @property
    def limit_concurrency(self) -> Optional[int]:
        """Maximum number of concurrently running requests.

        This engine imposes no limit, so the value is always None.
        """
        return None

608
609
610
611
612
613
    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` as the reason the engine is dead (see ``errored``)."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine dead and fail every tracked request with ``exc``.

        Passed to ``_log_task_completion`` by start_background_loop.
        """
        self.set_errored(exc)
        tracker = self._request_tracker
        tracker.propagate_exception(exc)
614

615
616
617
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer that the engine's tokenizer group resolves
        for ``lora_request`` (None selects the base tokenizer, presumably)."""
        tokenizer_group = self.engine.get_tokenizer_group()
        return await tokenizer_group.get_lora_tokenizer_async(lora_request)
621

622
    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the background loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # Built here rather than in __init__ so that it uses the right
        # event loop.
        self._request_tracker = RequestTracker()

        loop = asyncio.get_event_loop()
        # The loop coroutine only gets a weakref, so the running task does not
        # keep the engine alive by itself.
        task = loop.create_task(self.run_engine_loop(weakref.ref(self)))
        self._background_loop_unshielded = task
        task.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        # Cancelling the shielded future does not cancel the underlying task.
        self.background_loop = asyncio.shield(task)
Antoni Baum's avatar
Antoni Baum committed
637

638
639
640
641
642
643
644
645
646
647
648
649
650
651
    def shutdown_background_loop(self) -> None:
        """
        Shut down the background loop.

        Call this during cleanup. It cancels the loop task and drops both
        references to it, which releases the references to `self` so the
        resources held by the async LLM engine (e.g., the executors and their
        resources) can be garbage collected.
        """
        task = self._background_loop_unshielded
        if task is not None:
            task.cancel()
        self._background_loop_unshielded = None
        self.background_loop = None

652
    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Forwards newly queued requests and pending aborts from the request
        tracker to the engine, then runs one ``step_async`` on
        ``virtual_engine``.

        Returns:
            True unless every output produced by this step is finished.
            An empty batch counts as all finished, so it returns False.
        """
        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())

        for request_kwargs in new_requests:
            # Validation failures go to that request's stream instead of
            # propagating out of the engine loop.
            try:
                await self.engine.add_request_async(**request_kwargs)
            except ValueError as err:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    request_kwargs["request_id"],
                    err,
                    verbose=self.log_requests,
                )

        if aborted_requests:
            await self._engine_abort(aborted_requests)

        request_outputs = await self.engine.step_async(virtual_engine)

        if self.use_process_request_outputs_callback:
            # The callback already put these outputs into their streams
            # (invoked inside LLMEngine's _process_model_outputs). Only the
            # completion status is needed here.
            all_finished = all(out.finished for out in request_outputs)
        else:
            all_finished = self.process_request_outputs(request_outputs)

        return not all_finished

    def process_request_outputs(self, request_outputs) -> bool:
        """Forward each request output to its stream in the request tracker.

        Args:
            request_outputs: The outputs produced by an engine step.

        Returns:
            Whether every output is finished (True when there are none).
        """
        everything_done = True
        for out in request_outputs:
            self._request_tracker.process_request_output(
                out, verbose=self.log_requests)
            everything_done = everything_done and out.finished
        return everything_done
699

Antoni Baum's avatar
Antoni Baum committed
700
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Abort the given request ids in the underlying engine."""
        llm_engine = self.engine
        llm_engine.abort_request(request_ids)
Antoni Baum's avatar
Antoni Baum committed
702

703
704
705
706
707
708
709
710
    @staticmethod
    async def run_engine_loop(engine_ref: ReferenceType):
        """Background loop that keeps stepping the engine.

        Only a weak reference to the engine is held so that the running
        loop doesn't prevent the engine from being garbage collected; the
        loop returns once the referent is gone.

        One ``engine_step`` task is kept per virtual engine (one per
        pipeline-parallel stage). When a task completes, it is resubmitted
        if it reported in-progress requests or its virtual engine still has
        unfinished requests. When no virtual engine has work, the remote
        workers' execution loop is stopped and the loop idles until new
        requests arrive.

        Args:
            engine_ref: Weak reference to the ``AsyncLLMEngine`` to drive.

        Raises:
            asyncio.TimeoutError: If waiting on the in-flight steps exceeds
                ``ENGINE_ITERATION_TIMEOUT_S``; the engine is marked as
                errored before the error is re-raised.
        """
        engine: Optional["AsyncLLMEngine"] = engine_ref()
        if not engine:
            return

        pp_size = engine.engine.parallel_config.pipeline_parallel_size
        busy = [False] * pp_size

        while True:
            if not any(busy):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                await engine.engine.stop_remote_worker_execution_loop_async()
                tracker = engine._request_tracker
                # Release our strong reference so the engine can be garbage
                # collected while we idle waiting for new requests.
                del engine
                await asyncio.sleep(0)
                if engine_ref() is None:
                    return
                await tracker.wait_for_new_requests()
                engine = engine_ref()
                if not engine:
                    return
                logger.debug("Got new requests!")
                step_tasks = [
                    asyncio.create_task(engine.engine_step(ve))
                    for ve in range(pp_size)
                ]
                busy = [True] * pp_size

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        step_tasks, return_when=asyncio.FIRST_COMPLETED)
                    for _ in range(pp_size):
                        await asyncio.sleep(0)
                for completed in done:
                    step_result = completed.result()
                    ve_idx = step_tasks.index(completed)
                    has_pending = (engine.engine.
                                   has_unfinished_requests_for_virtual_engine(
                                       ve_idx))
                    if step_result or has_pending:
                        step_tasks[ve_idx] = asyncio.create_task(
                            engine.engine_step(ve_idx))
                        busy[ve_idx] = True
                    else:
                        busy[ve_idx] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                engine.set_errored(exc)
                raise

            await asyncio.sleep(0)

772
773
    # This method does not need to be async, but kept that way
    # for backwards compatibility.
Antoni Baum's avatar
Antoni Baum committed
774
775
776
    async def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Queue a request and return the async generator of its outputs.

        The request is registered with the request tracker; the background
        loop hands it to the underlying engine on a later ``engine_step``.

        Args:
            request_id: The unique id of the request.
            inputs: The inputs to the LLM.
            params: Sampling parameters (used by ``generate``) or pooling
                parameters (used by ``encode``) for the request.
            arrival_time: Arrival timestamp of the request. When ``None``,
                the current ``time.time()`` is used; an explicit value,
                including ``0.0``, is passed through unchanged.
            lora_request: LoRA request to use, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            The generator of the stream the tracker created for the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                ``start_engine_loop`` is disabled, so it cannot be started.
        """
        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        # Only fall back to "now" when no arrival time was supplied;
        # `arrival_time or time.time()` would also discard an explicit 0.0.
        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            inputs=inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request)

        return stream.generator()
805

806
    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        This is an async generator. It adds the request to the waiting queue
        of the LLMEngine (through :meth:`add_request`) and then streams the
        outputs from the LLMEngine back to the caller as they are produced.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use
                for generation, if any.

        Yields:
            The `RequestOutput` objects from the LLMEngine for the request,
            each checked with `LLMEngine.validate_output`.

        Raises:
            AsyncEngineDeadError: Propagated from `add_request` when the
                background loop is not running and cannot be started.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the client's HTTP request (see api_server.py)
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        output_stream = await self.add_request(
            request_id,
            inputs,
            sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
        )
        async for request_output in output_stream:
            yield LLMEngine.validate_output(request_output, RequestOutput)
888
889
890

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model.

        This is an async generator. It adds the request to the waiting queue
        of the LLMEngine (through :meth:`add_request`) and then streams the
        outputs from the LLMEngine back to the caller as they are produced.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.

        Yields:
            The `EmbeddingRequestOutput` objects from the LLMEngine for the
            request, each checked with `LLMEngine.validate_output`.

        Raises:
            AsyncEngineDeadError: Propagated from `add_request` when the
                background loop is not running and cannot be started.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the client's HTTP request (see api_server.py)
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        output_stream = await self.add_request(
            request_id,
            inputs,
            pooling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )
        async for request_output in output_stream:
            yield LLMEngine.validate_output(request_output,
                                            EmbeddingRequestOutput)
965

Antoni Baum's avatar
Antoni Baum committed
966
967
    async def abort(self, request_id: str) -> None:
        """Abort a submitted request.

        Aborting a request that has already finished, or that is unknown, is
        a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if self.is_running:
            return self._abort(request_id)

        raise AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")
983

Antoni Baum's avatar
Antoni Baum committed
984
    def _abort(self, request_id: str) -> None:
        """Synchronously abort a submitted request via the request tracker.

        The request's stream is aborted with ``asyncio.CancelledError``.
        Unknown or already-finished requests are a no-op.

        Args:
            request_id: The unique id of the request.
        """
        tracker = self._request_tracker
        tracker.abort_request(
            request_id,
            exception=asyncio.CancelledError,
            verbose=self.log_requests,
        )
996

997
998
    async def get_model_config(self) -> ModelConfig:
        """Return the model configuration held by the wrapped vLLM engine."""
        model_config = self.engine.get_model_config()
        return model_config
1000

1001
1002
    async def get_parallel_config(self) -> ParallelConfig:
        """Return the parallel configuration held by the wrapped vLLM engine."""
        parallel_config = self.engine.get_parallel_config()
        return parallel_config
1004

1005
1006
    async def get_decoding_config(self) -> DecodingConfig:
        """Return the decoding configuration held by the wrapped vLLM engine."""
        decoding_config = self.engine.get_decoding_config()
        return decoding_config
1008

1009
1010
    async def get_scheduler_config(self) -> SchedulerConfig:
        """Return the scheduling configuration of the wrapped vLLM engine."""
        scheduler_config = self.engine.get_scheduler_config()
        return scheduler_config
1012
1013
1014

    async def get_lora_config(self) -> LoRAConfig:
        """Return the LoRA configuration held by the wrapped vLLM engine."""
        lora_config = self.engine.get_lora_config()
        return lora_config
1016

1017
1018
1019
1020
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the wrapped engine to log its stats.

        NOTE: ``scheduler_outputs`` and ``model_output`` are accepted but
        not forwarded; the wrapped engine's ``do_log_stats`` is always
        called with no arguments.
        """
        llm_engine = self.engine
        llm_engine.do_log_stats()
1022

1023
    async def check_health(self) -> None:
        """Raise an error if the engine is unhealthy.

        Raises:
            AsyncEngineDeadError: If the background loop is stopped.
            Any error raised by the wrapped engine's ``check_health_async``
            is propagated unchanged.
        """
        started_at = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        await self.engine.check_health_async()
        elapsed = time.perf_counter() - started_at
        logger.debug("Health check took %fs", elapsed)
1032
1033

    async def is_tracing_enabled(self) -> bool:
        """Report whether tracing is enabled on the wrapped engine."""
        tracing_on = self.engine.is_tracing_enabled()
        return tracing_on
1035
1036

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger under ``logger_name`` on the wrapped engine."""
        llm_engine = self.engine
        llm_engine.add_logger(logger_name=logger_name, logger=logger)
1038
1039

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger named ``logger_name`` on the engine."""
        llm_engine = self.engine
        llm_engine.remove_logger(logger_name=logger_name)
1041
1042

    async def start_profile(self) -> None:
        """Start profiling on the model executor.

        An exact type check is used instead of ``isinstance`` on purpose.
        Only an executor whose type is exactly ``GPUExecutorAsync`` is asked
        directly; any other executor, including subclasses, forwards
        ``start_profile`` to its workers.
        """
        executor = self.engine.model_executor
        if type(executor) is GPUExecutorAsync:
            executor.start_profile()
        else:
            executor._run_workers("start_profile")
1049
1050

    async def stop_profile(self) -> None:
        """Stop profiling on the model executor.

        An exact type check is used instead of ``isinstance`` on purpose.
        Only an executor whose type is exactly ``GPUExecutorAsync`` is asked
        directly; any other executor, including subclasses, forwards
        ``stop_profile`` to its workers.
        """
        executor = self.engine.model_executor
        if type(executor) is GPUExecutorAsync:
            executor.stop_profile()
        else:
            executor._run_workers("stop_profile")