async_llm_engine.py 41.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import asyncio
import time
6
import weakref
Antoni Baum's avatar
Antoni Baum committed
7
from functools import partial
8
9
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                    Mapping, Optional, Set, Tuple, Type, Union)
10
from weakref import ReferenceType
11

12
import vllm.envs as envs
13
14
15
from vllm.config import (DecodingConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig, VllmConfig)
from vllm.config.lora import LoRAConfig
16
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
17
from vllm.engine.arg_utils import AsyncEngineArgs
18
from vllm.engine.async_timeout import asyncio_timeout
19
from vllm.engine.llm_engine import LLMEngine
20
from vllm.engine.metrics_types import StatLoggerBase
21
from vllm.engine.protocol import EngineClient
22
from vllm.executor.executor_base import ExecutorBase
23
from vllm.inputs import PromptType
24
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
25
from vllm.logger import init_logger
26
from vllm.lora.request import LoRARequest
27
from vllm.model_executor.layers.sampler import SamplerOutput
28
from vllm.outputs import PoolingRequestOutput, RequestOutput
29
from vllm.pooling_params import PoolingParams
30
from vllm.sampling_params import SamplingParams
31
from vllm.sequence import ExecuteModelRequest
32
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
33
from vllm.usage.usage_lib import UsageContext
34
from vllm.utils import Device, deprecate_kwargs, weak_bind
35
36

logger = init_logger(__name__)

# Per-iteration engine timeout in seconds, read once at import time from the
# VLLM_ENGINE_ITERATION_TIMEOUT_S setting.
# NOTE(review): presumably applied around each engine step via
# asyncio_timeout in run_engine_loop — confirm against that function.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S


class AsyncEngineDeadError(RuntimeError):
    """Signals that the async engine's background loop is dead or unusable.

    Raised when the engine-loop task exits with an exception and when a
    restart of an already-errored loop is attempted; also handed out (not
    raised) as a ready-made error describing a stopped loop.
    """


def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback for the `engine.run_engine_loop()` task only.

    That task runs a `while True` loop that can only exit if there is an
    exception, so its completion is interpreted as follows:

    * cancelled: logged as a graceful shutdown (expected only on exit);
    * finished with an ``Exception``: logged, passed to ``error_callback``
      and re-raised as ``AsyncEngineDeadError``;
    * returned normally: an ``AssertionError`` is raised inside the ``try``
      and therefore takes the same path as any other exception above.

    Args:
        task: The finished engine-loop task.
        error_callback: Called with the exception that stopped the loop.

    Raises:
        AsyncEngineDeadError: Whenever the task did not end by cancellation.
    """
    try:
        return_value = task.result()
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {return_value}")
    except asyncio.exceptions.CancelledError:
        # We assume that if the task is cancelled, we are gracefully shutting
        # down. This should only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on GitHub. See stack trace above for the "
            "actual cause.") from e


# Sentinel queued by ``AsyncStream.finish()`` to mark a normal end of stream.
# ``AsyncStream.generator`` recognises it and returns instead of raising it.
STOP_ITERATION = Exception()


class AsyncStream:
    """A stream of RequestOutputs for a request that can be iterated over
    asynchronously via an async generator.

    Items travel through an unbounded ``asyncio.Queue``. ``finish()`` ends
    the stream by enqueueing either an exception (re-raised to the consumer)
    or the ``STOP_ITERATION`` sentinel (ends iteration normally). Anything
    ``put()`` after ``finish()`` is silently dropped.
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        # Invoked with request_id if the consumer closes the generator early.
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, Exception]) -> None:
        """Enqueue an output or exception unless the stream is finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Mark the stream finished. Calls after the first are no-ops.

        Args:
            exception: If an exception instance or class, the consumer gets
                it raised after any already-queued items; otherwise the
                stream terminates normally.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        return self._finished

    async def generator(self) -> AsyncGenerator[RequestOutput, None]:
        """Yield queued outputs until the stream is finished.

        Raises:
            BaseException: The exception (or class) given to ``finish()``.
            asyncio.CancelledError: If the generator is closed before the
                stream finished; ``cancel(request_id)`` is called first.
        """
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    # The sentinel is a unique object: compare by identity.
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            # Consumer abandoned the stream: cancel the underlying request.
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """Return True for exception instances and exception classes."""
        return isinstance(value, BaseException) or (
            isinstance(value, type) and issubclass(value, BaseException))


class RequestTracker:
    """Synchronous abstraction for tracking requests.

    New requests and aborts are buffered in asyncio queues and handed to the
    engine in batches by ``get_new_and_aborted_requests()``. At that point
    each new request's ``AsyncStream`` is registered, and it then receives
    outputs through ``process_request_output()``.
    """

    def __init__(self) -> None:
        # Streams of requests that have been handed to the engine.
        self._request_streams: Dict[str, AsyncStream] = {}
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, kwargs for engine.add_request) pairs not yet handed over.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None)."""
        if request_id is not None:
            self.abort_request(request_id, exception=exc)
        else:
            # NB: tuple() used here because self.abort_request pops the stream
            # out of self._request_streams, so we can't iterate on it directly
            for rid in tuple(self._request_streams.keys()):
                self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: RequestOutput,
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        The output is routed to the request's stream; a finished output also
        unregisters and finishes the stream.
        """
        request_id = request_output.request_id
        finished = request_output.finished

        if finished:
            stream = self._request_streams.pop(request_id, None)
        else:
            stream = self._request_streams.get(request_id)
        # Guard against a KeyError which can occur if the request was aborted
        # while the output was generated
        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if verbose and finished:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Raises:
            KeyError: If a stream for ``request_id`` is already registered.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        abort_request = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, abort_request)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Abort a request during next background loop iteration.

        The id is always queued for the engine. If the request already has a
        registered stream, that stream is removed and finished with
        ``exception`` (or normally when it is None).
        """
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        Both queues are drained. A request aborted before it was ever handed
        to the engine gets its stream finished with CancelledError and is
        left out of both returned collections.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._aborted_requests.empty():
            request_id = self._aborted_requests.get_nowait()
            finished_requests.add(request_id)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            request_id = stream.request_id
            if request_id in finished_requests:
                # The request has already been aborted.
                stream.finish(asyncio.CancelledError)
                finished_requests.discard(request_id)
            else:
                self._request_streams[request_id] = stream
                new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Wait until a new request is queued (or the wake-up event is set
        by other means), then clear the event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        return not self._new_requests.empty()


class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def step_async(self, virtual_engine: int) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # Only consumed when scheduler_outputs is non-empty; bound up front so
        # it is never referenced unassigned.
        finished_requests_ids = []

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):

            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            if not scheduler_outputs.is_empty():
                # this will cause mamba_cache/minimax_cache failed
                # to release finished_requests_ids of the last steps
                finished_requests_ids = self.scheduler[
                    virtual_engine].get_and_reset_finished_requests_ids()

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

        else:
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []

        if not self._has_remaining_steps(seq_group_metadata_list):
            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            if outputs and allow_async_output_proc:
                assert len(
                    outputs
                ) == 1, "Async postprocessor expects only a single output set"
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)

        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def get_tokenizer_async(self,
                                  lora_request: Optional[LoRARequest] = None
                                  ) -> AnyTokenizer:
        """Return the tokenizer, resolved for ``lora_request`` if given."""
        return await (
            self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))

    async def add_request_async(
        self,
        request_id: str,
        prompt: PromptType,
        params: SamplingParams,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        """
        Async version of
        [`add_request`][vllm.engine.llm_engine.LLMEngine.add_request].

        Raises:
            ValueError: If a LoRA request is given without LoRA enabled, a
                non-zero priority is given without priority scheduling, or
                ``data_parallel_rank`` is set (V1-only feature).
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if priority != 0 and not self.scheduler_config.policy == "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")
        if arrival_time is None:
            arrival_time = time.time()

        if data_parallel_rank is not None:
            raise ValueError("Targeting data_parallel_rank only supported "
                             "in v1 client.")

        if (isinstance(prompt, dict)
                and prompt.get("prompt_embeds", None) is not None
                and not prompt.get("prompt_token_ids", None)):
            # We use the -2 dimension (instead of 0) in case a batched input
            # of batch size 1 is passed in.
            # NOTE(review): this writes placeholder ids into the caller's
            # prompt dict in place.
            prompt["prompt_token_ids"] = [0
                                          ] * prompt["prompt_embeds"].shape[-2]

        processed_inputs = await self.input_preprocessor.preprocess_async(
            prompt,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    async def check_health_async(self) -> None:
        """Run the executor's (synchronous) health check."""
        self.model_executor.check_health()

    async def collective_rpc_async(self,
                                   method: str,
                                   timeout: Optional[float] = None,
                                   args: tuple = (),
                                   kwargs: Optional[dict] = None):
        """Not supported by this engine; always raises NotImplementedError."""
        raise NotImplementedError


class AsyncLLMEngine(EngineClient):
    """An asynchronous wrapper for [`LLMEngine`][vllm.LLMEngine].

    This class is used to wrap the [`LLMEngine`][vllm.LLMEngine] class to
    make it asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The [`LLMEngine`][vllm.LLMEngine] is kicked
    by the generate method when there are requests in the waiting queue. The
    generate method yields the outputs from the [`LLMEngine`][vllm.LLMEngine]
    to the caller.

    Args:
        log_requests: Whether to log the requests.
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
        *args: Arguments for [`LLMEngine`][vllm.LLMEngine].
        **kwargs: Arguments for [`LLMEngine`][vllm.LLMEngine].
    """

    # Engine type instantiated by __init__ with the forwarded *args/**kwargs.
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

    def __init__(self,
                 *args: Any,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs: Any) -> None:
        """Build the wrapped engine. The background loop is not started here.

        Args:
            *args: Positional arguments forwarded to ``_engine_class``.
            log_requests: Passed as ``verbose`` to the request tracker so
                per-request events are logged.
            start_engine_loop: Stored as ``self.start_engine_loop``.
            **kwargs: Keyword arguments forwarded to ``_engine_class``.

        Raises:
            ValueError: If ``envs.VLLM_USE_V1`` is set; this is the V0 engine.
        """
        if envs.VLLM_USE_V1:
            raise ValueError(
                "Using V0 AsyncLLMEngine, but envs.VLLM_USE_V1=True. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        self.log_requests = log_requests
        self.engine = self._engine_class(*args, **kwargs)

        # This ensures quick processing of request outputs
        # so the append to asyncio queues is not delayed,
        # especially for multi-step.
        self.use_process_request_outputs_callback = (
            self.engine.model_config.use_async_output_proc)

        if self.use_process_request_outputs_callback:
            # weak_bind so the engine's callback does not keep self alive.
            self.engine.process_request_outputs_callback = \
                weak_bind(self.process_request_outputs)

        self.background_loop: Optional[asyncio.Future] = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Lazy initialized fields (assigned in start_background_loop)
        self._request_tracker: RequestTracker

    def __del__(self):
        """Wake the engine loop so it sees the engine is gone and exits.

        The tracker lives in ``_request_tracker`` and exists only once
        ``start_background_loop`` has run.
        """
        # Previously this looked up the non-existent "request_tracker"
        # attribute, so the loop was never woken.
        rt = getattr(self, "_request_tracker", None)
        # Compare with None explicitly: RequestTracker defines __len__, so a
        # tracker with no in-flight requests is falsy -- and that idle state is
        # exactly when the loop is blocked waiting and needs the wake-up.
        if rt is not None:
            # Wake up engine loop so that it will exit cleanly
            rt.new_requests_event.set()

    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Return the executor class for ``engine_config``.

        Delegates to ``LLMEngine._get_executor_cls``.
        """
        return LLMEngine._get_executor_cls(engine_config)

    @classmethod
    @deprecate_kwargs(
        "disable_log_requests",
        additional_message=("This argument will have no effect. "
                            "Use `enable_log_requests` instead."),
    )
    def from_vllm_config(
            cls,
            vllm_config: VllmConfig,
            start_engine_loop: bool = True,
            usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
            stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
            enable_log_requests: bool = False,
            disable_log_stats: bool = False,
            disable_log_requests: bool = True,  # Deprecated, will be removed
    ) -> "AsyncLLMEngine":
        """Create an AsyncLLMEngine from a VllmConfig.

        ``disable_log_requests`` is deprecated and ignored. Request logging is
        controlled by ``enable_log_requests``, and stats logging by the
        negation of ``disable_log_stats``.
        """
        return cls(
            vllm_config=vllm_config,
            executor_class=cls._get_executor_cls(vllm_config),
            start_engine_loop=start_engine_loop,
            log_requests=enable_log_requests,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        Builds a VllmConfig from ``engine_args`` and delegates to
        ``from_vllm_config``. When ``envs.VLLM_USE_V1`` is set, the delegation
        goes to the V1 ``AsyncLLM`` class instead of ``cls``.
        """
        vllm_config = engine_args.create_engine_config(usage_context)

        async_engine_cls = cls
        if envs.VLLM_USE_V1:
            from vllm.v1.engine.async_llm import AsyncLLM as V1AsyncLLMEngine
            async_engine_cls = V1AsyncLLMEngine

        return async_engine_cls.from_vllm_config(
            vllm_config=vllm_config,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            disable_log_stats=engine_args.disable_log_stats,
            enable_log_requests=engine_args.enable_log_requests,
        )

    @property
    def is_running(self) -> bool:
        """True while the background loop task exists and is not done."""
        return (self.background_loop is not None
                and self._background_loop_unshielded is not None
                and not self._background_loop_unshielded.done())

    @property
    def is_stopped(self) -> bool:
        """True if the engine errored or its background loop task is done."""
        return self.errored or (self.background_loop is not None and
                                self._background_loop_unshielded is not None
                                and self._background_loop_unshielded.done())

    @property
    def errored(self) -> bool:
        """True once a fatal exception has been recorded via set_errored."""
        return self._errored_with is not None

    @property
    def dead_error(self) -> BaseException:
        """A new AsyncEngineDeadError describing a stopped background loop.

        The exception is returned, not raised; the caller decides whether to
        raise it.
        """
        return AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")

    def set_errored(self, exc: Exception) -> None:
        """Remember ``exc`` as the fatal error of this engine.

        Afterwards ``errored`` reports True, and ``start_background_loop``
        refuses to start again, chaining this exception as the cause.
        """
        recorded_cause: Exception = exc
        self._errored_with = recorded_cause

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine errored and fail every tracked request with exc.

        start_background_loop passes this to _log_task_completion, which
        calls it when the engine-loop task dies with an exception.
        """
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)

    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Return the wrapped engine's input preprocessor."""
        return self.engine.input_preprocessor

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the engine tokenizer, resolved for lora_request if given."""
        return await self.engine.get_tokenizer_async(lora_request)

    def start_background_loop(self) -> None:
        """Start the background loop.

        Creates a fresh RequestTracker and schedules ``run_engine_loop`` as a
        task that is given only a weak reference to ``self``.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        self._background_loop_unshielded = asyncio.get_event_loop(
        ).create_task(self.run_engine_loop(weakref.ref(self)))
        self._background_loop_unshielded.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        # Shield: cancelling something awaiting background_loop must not
        # cancel the engine task itself.
        self.background_loop = asyncio.shield(self._background_loop_unshielded)

    def shutdown_background_loop(self) -> None:
        """
        Shut down the background loop.

        This method needs to be called during cleanup to remove
        references to `self` and properly GC the resources held
        by the async LLM engine (e.g., the executors as well as
        their resources).
        """
        if self._background_loop_unshielded is not None:
            self._background_loop_unshielded.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Hands newly added and aborted requests to the engine, runs one
        ``step_async`` for ``virtual_engine``, and routes the outputs to
        the per-request streams. A new request that fails validation with
        ValueError is failed on its own stream rather than here.

        Returns:
            False if every output of this step is finished (including when
            the step produced no outputs), True otherwise.
        """
        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())

        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            try:
                await self.engine.add_request_async(**new_request)
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    new_request["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if aborted_requests:
            await self._engine_abort(aborted_requests)

        request_outputs = await self.engine.step_async(virtual_engine)

        # Put the outputs into the corresponding streams.
        # If used as a callback, then already invoked inside
        # LLMEngine's _process_model_outputs
        if not self.use_process_request_outputs_callback:
            all_finished = self.process_request_outputs(request_outputs)
        else:
            # For callback case, we only need to detect when all
            # requests are finished
            all_finished = all(request_output.finished
                               for request_output in request_outputs)

        return not all_finished

    def process_request_outputs(self, request_outputs) -> bool:
        """Route each output to its request stream.

        Every output is processed; there is no early exit.

        Returns:
            True if every output is finished (vacuously True when empty).
        """
        # Put the outputs into the corresponding streams.
        all_finished = True
        for request_output in request_outputs:
            self._request_tracker.process_request_output(
                request_output, verbose=self.log_requests)
            all_finished = all_finished and request_output.finished

        return all_finished

    async def _engine_abort(self, request_ids: Iterable[str]):
        """Forward the given request ids to the wrapped engine's abort."""
        self.engine.abort_request(request_ids)

    @staticmethod
    async def run_engine_loop(engine_ref: ReferenceType):
        """Background loop that keeps ``engine_step`` running per virtual engine.

        We use a weakref to the engine so that the running loop
        doesn't prevent the engine being garbage collected.

        One ``engine_step`` task is scheduled per pipeline-parallel virtual
        engine. A virtual engine is rescheduled while its last step reported
        more work or the underlying engine still has unfinished requests for
        it. When no virtual engine has work, the loop stops the remote worker
        execution loop and blocks on the request tracker until new requests
        arrive.

        Args:
            engine_ref: Weak reference to the owning ``AsyncLLMEngine``. The
                loop returns when it observes the referent has been
                collected.

        Raises:
            asyncio.TimeoutError: If one iteration exceeds
                ``ENGINE_ITERATION_TIMEOUT_S``; the engine is marked errored
                before re-raising.
            Any exception raised inside an ``engine_step`` task propagates
            through ``task.result()``.
        """
        engine: Optional[AsyncLLMEngine] = engine_ref()
        if not engine:
            return

        pipeline_parallel_size = \
                engine.engine.parallel_config.pipeline_parallel_size
        # has_requests_in_progress[ve] is True while virtual engine ``ve`` has
        # a freshly scheduled engine_step task that still needs checking.
        has_requests_in_progress = [False] * pipeline_parallel_size
        while True:
            if not any(has_requests_in_progress):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # time out, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                await engine.engine.stop_remote_worker_execution_loop_async()
                request_tracker = engine._request_tracker
                # Allow engine to be garbage collected while
                # waiting for new requests
                del engine
                await asyncio.sleep(0)
                if engine_ref() is None:
                    return
                await request_tracker.wait_for_new_requests()
                engine = engine_ref()
                if not engine:
                    return
                logger.debug("Got new requests!")
                requests_in_progress = [
                    asyncio.create_task(engine.engine_step(ve))
                    for ve in range(pipeline_parallel_size)
                ]
                has_requests_in_progress = [True] * pipeline_parallel_size

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        requests_in_progress,
                        return_when=asyncio.FIRST_COMPLETED)
                    # Yield to the event loop once per virtual engine before
                    # processing results. NOTE(review): presumably this lets
                    # other ready tasks run first; the intent is not evident
                    # from this code alone.
                    for _ in range(pipeline_parallel_size):
                        await asyncio.sleep(0)
                # NOTE(review): an idle virtual engine (else-branch below)
                # keeps its already-completed task in requests_in_progress.
                # While any other virtual engine is busy, asyncio.wait
                # therefore returns immediately and the idle one is re-checked
                # on every iteration. This looks like intentional polling so
                # idle virtual engines pick up new work; confirm before
                # changing.
                for task in done:
                    result = task.result()
                    virtual_engine = requests_in_progress.index(task)
                    has_unfinished_requests = (
                        engine.engine.
                        has_unfinished_requests_for_virtual_engine(
                            virtual_engine))
                    if result or has_unfinished_requests:
                        requests_in_progress[virtual_engine] = (
                            asyncio.create_task(
                                engine.engine_step(virtual_engine)))
                        has_requests_in_progress[virtual_engine] = True
                    else:
                        has_requests_in_progress[virtual_engine] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                engine.set_errored(exc)
                raise

            await asyncio.sleep(0)

773
    async def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: SamplingParams,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Register a request with the request tracker and return its stream.

        If the background loop is not running, it is started when
        ``self.start_engine_loop`` is set. The request is queued on the
        request tracker; the background loop forwards it to the underlying
        engine on a later step.

        Args:
            request_id: The unique id of the request.
            prompt: The prompt for the request.
            params: The sampling parameters of the request.
            arrival_time: Arrival timestamp. The current ``time.time()`` is
                used only when this is ``None``.
            lora_request: LoRA request to use, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request. Must be 0 unless the
                scheduler policy is ``"priority"``.
            data_parallel_rank: The data parallel rank that must handle this
                request, if any.
            tokenization_kwargs: Extra keyword arguments for tokenization,
                passed through to the tracker.

        Returns:
            The async generator of the request's ``AsyncStream``.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                ``self.start_engine_loop`` is false.
            ValueError: If a non-zero ``priority`` is given while priority
                scheduling is not enabled.
        """
        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        if (priority != 0
                and self.engine.scheduler_config.policy != "priority"):
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        # Fall back to "now" only when no arrival time was supplied. The
        # previous `arrival_time or time.time()` also discarded an explicit
        # 0.0.
        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            prompt=prompt,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=priority,
            data_parallel_rank=data_parallel_rank,
            tokenization_kwargs=tokenization_kwargs,
        )

        return stream.generator()
814

815
    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is an async generator. It
        adds the request into the waiting queue of the LLMEngine and streams
        the outputs from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See
                [`PromptType`][vllm.inputs.PromptType] for more details about
                the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request. Must be a `str`;
                `abort` rejects any other type.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.
            data_parallel_rank: The (global) data parallel rank that must
                handle this request. Only applicable if DP is enabled.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request, each checked with `LLMEngine.validate_output`.

        Raises:
            asyncio.CancelledError: Re-raised after the request has been
                aborted if the consuming task is cancelled.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              [`engine_step`][vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step]
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> # note that engine_args here is AsyncEngineArgs instance
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",  # request ids are strings
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is assumed to be the incoming client request,
            >>>     # as in api_server.py
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        try:
            async for output in await self.add_request(
                    request_id,
                    prompt,
                    sampling_params,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=priority,
                    data_parallel_rank=data_parallel_rank,
            ):
                yield LLMEngine.validate_output(output, RequestOutput)
        except asyncio.CancelledError:
            # When the consumer's task is cancelled, abort the request so it
            # is removed from the tracker, then propagate the cancellation.
            await self.abort(request_id)
            raise
905

906
    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Pooling entry point, which this V0 engine does not support.

        Raises:
            NotImplementedError: Always, whatever the arguments.
        """
        message = "Pooling models are not supported in vLLM V0"
        raise NotImplementedError(message)
918

919
    async def abort(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort a submitted request.

        Aborting a request that has already finished, or is unknown, does
        nothing.

        Args:
            request_id: The unique id of the request. Only a single `str`
                id is accepted by this deprecated V0 engine.

        Raises:
            RuntimeError: If `request_id` is not a `str`.
            AsyncEngineDeadError: If the background loop is not running.
        """
        if isinstance(request_id, str):
            if self.is_running:
                return self._abort(request_id)
            raise AsyncEngineDeadError(
                "Background loop is not running. If it was running, "
                "inspect the output to find the stacktrace of the "
                "error that caused the background loop to stop "
                "(AsyncEngineDeadError).")
        raise RuntimeError("Only single-request abort supported in"
                           " deprecated V0")
939

Antoni Baum's avatar
Antoni Baum committed
940
    def _abort(self, request_id: str) -> None:
        """Cancel ``request_id`` in the request tracker.

        The tracker is told to abort the request with
        ``asyncio.CancelledError`` as its exception. A request that has
        already finished, or is unknown, is a no-op.

        Args:
            request_id: The unique id of the request.
        """
        tracker = self._request_tracker
        tracker.abort_request(
            request_id,
            exception=asyncio.CancelledError,
            verbose=self.log_requests,
        )
952

953
954
955
956
    async def get_vllm_config(self) -> VllmConfig:
        """Return the wrapped engine's top-level vLLM configuration."""
        llm_engine = self.engine
        return llm_engine.get_vllm_config()

957
958
    async def get_model_config(self) -> ModelConfig:
        """Return the wrapped engine's model configuration."""
        llm_engine = self.engine
        return llm_engine.get_model_config()
960

961
962
    async def get_parallel_config(self) -> ParallelConfig:
        """Return the wrapped engine's parallelism configuration."""
        llm_engine = self.engine
        return llm_engine.get_parallel_config()
964

965
966
    async def get_decoding_config(self) -> DecodingConfig:
        """Return the wrapped engine's decoding configuration."""
        llm_engine = self.engine
        return llm_engine.get_decoding_config()
968

969
970
    async def get_scheduler_config(self) -> SchedulerConfig:
        """Return the wrapped engine's scheduling configuration."""
        llm_engine = self.engine
        return llm_engine.get_scheduler_config()
972
973
974

    async def get_lora_config(self) -> LoRAConfig:
        """Return the wrapped engine's LoRA configuration."""
        llm_engine = self.engine
        return llm_engine.get_lora_config()
976

977
978
979
980
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the wrapped engine to log its stats.

        ``scheduler_outputs`` and ``model_output`` are accepted but are not
        forwarded to the engine.
        """
        llm_engine = self.engine
        llm_engine.do_log_stats()
982

983
    async def check_health(self) -> None:
        """Raise an error if the engine is unhealthy.

        Raises:
            AsyncEngineDeadError: If the background loop is stopped.

        Any exception raised by the wrapped engine's ``check_health_async``
        propagates unchanged. The elapsed time is logged at debug level.
        """
        started = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        await self.engine.check_health_async()
        elapsed = time.perf_counter() - started
        logger.debug("Health check took %fs", elapsed)
992
993

    async def is_tracing_enabled(self) -> bool:
        """Report whether the wrapped engine has tracing enabled."""
        tracing_on = self.engine.is_tracing_enabled()
        return tracing_on
995
996

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register the stat logger ``logger`` under ``logger_name``."""
        target = self.engine
        target.add_logger(logger_name=logger_name, logger=logger)
998
999

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger registered under ``logger_name``."""
        target = self.engine
        target.remove_logger(logger_name=logger_name)
1001
1002

    async def start_profile(self) -> None:
        """Start profiling on the wrapped engine."""
        llm_engine = self.engine
        llm_engine.start_profile()
1004
1005

    async def stop_profile(self) -> None:
        """Stop profiling on the wrapped engine."""
        llm_engine = self.engine
        llm_engine.stop_profile()
1007

1008
1009
1010
    async def reset_mm_cache(self) -> None:
        """Clear the wrapped engine's multi-modal cache."""
        llm_engine = self.engine
        llm_engine.reset_mm_cache()

1011
1012
1013
    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        """Clear the wrapped engine's prefix cache.

        Args:
            device: Passed through to the engine unchanged; ``None`` when
                omitted.
        """
        llm_engine = self.engine
        llm_engine.reset_prefix_cache(device)
1014

1015
    async def sleep(self, level: int = 1) -> None:
        """Put the wrapped engine to sleep at ``level``.

        The prefix cache is cleared first, using the default ``device``
        argument of ``reset_prefix_cache``.
        """
        await self.reset_prefix_cache()
        llm_engine = self.engine
        llm_engine.sleep(level)

1019
1020
    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake the wrapped engine, forwarding ``tags`` unchanged."""
        llm_engine = self.engine
        llm_engine.wake_up(tags)
1021

1022
1023
1024
    async def is_sleeping(self) -> bool:
        """Report whether the wrapped engine is currently sleeping."""
        asleep = self.engine.is_sleeping()
        return asleep

1025
1026
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter into the wrapped engine.

        Returns:
            Whatever the engine's ``add_lora`` returns.
        """
        llm_engine = self.engine
        return llm_engine.add_lora(lora_request)
1027

1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
    async def collective_rpc(self,
                             method: str,
                             timeout: Optional[float] = None,
                             args: tuple = (),
                             kwargs: Optional[dict] = None):
        """Perform a collective RPC call of ``method`` via the wrapped engine.

        Args:
            method: The method to call.
            timeout: Optional timeout, forwarded unchanged.
            args: Positional arguments for the call.
            kwargs: Keyword arguments for the call, or ``None``.

        Returns:
            The awaited result of the engine's ``collective_rpc_async``.
        """
        llm_engine = self.engine
        rpc_result = await llm_engine.collective_rpc_async(
            method, timeout, args, kwargs)
        return rpc_result

1039
1040

# TODO(v1): Remove this class proxy when V1 goes default.
# If VLLM_USE_V1 is explicitly set to a truthy value, rebind the public
# ``AsyncLLMEngine`` name in this module to the V1 ``AsyncLLM`` class, so code
# that imports ``AsyncLLMEngine`` from here gets the V1 engine.
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
    from vllm.v1.engine.async_llm import AsyncLLM

    AsyncLLMEngine = AsyncLLM  # type: ignore