async_llm_engine.py 47.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import copy
6
import time
7
import weakref
Antoni Baum's avatar
Antoni Baum committed
8
from functools import partial
9
10
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                    Mapping, Optional, Set, Tuple, Type, Union)
11
from weakref import ReferenceType
12

13
import vllm.envs as envs
14
15
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, VllmConfig)
16
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
17
from vllm.engine.arg_utils import AsyncEngineArgs
18
from vllm.engine.async_timeout import asyncio_timeout
19
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
20
from vllm.engine.metrics_types import StatLoggerBase
21
from vllm.engine.protocol import EngineClient
22
from vllm.executor.executor_base import ExecutorBase
23
from vllm.inputs import PromptType
24
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
25
from vllm.logger import init_logger
26
from vllm.lora.request import LoRARequest
27
28
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
29
from vllm.model_executor.layers.sampler import SamplerOutput
30
from vllm.outputs import PoolingRequestOutput, RequestOutput
31
from vllm.pooling_params import PoolingParams
32
from vllm.sampling_params import SamplingParams
33
from vllm.sequence import ExecuteModelRequest
34
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
35
from vllm.usage.usage_lib import UsageContext
36
from vllm.utils import Device, weak_bind
37
38

# Module-level logger for this engine wrapper.
logger = init_logger(__name__)
# Engine iteration timeout taken from vllm.envs (seconds, judging by the
# `_S` suffix — confirm against vllm.envs).
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
40

Antoni Baum's avatar
Antoni Baum committed
41

42
43
44
45
class AsyncEngineDeadError(RuntimeError):
    """Raised when the async engine's background loop has stopped or failed
    and can no longer serve requests."""


46
47
48
49
50
51
52
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback for the `engine.run_engine_loop()` background task.

    That task runs a `while True` loop, so it is only expected to stop by
    being cancelled (treated as a graceful shutdown) or by raising. A normal
    return is therefore treated as a failure as well.

    Args:
        task: The background task that just completed.
        error_callback: Called with the exception when the task failed (or
            returned without raising).

    Raises:
        AsyncEngineDeadError: Whenever the task did not end by cancellation.
    """
    try:
        outcome = task.result()
        # Getting here means the loop returned instead of raising; turn that
        # into an exception so the failure branch below handles it.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {outcome}")
    except asyncio.exceptions.CancelledError:
        # Cancellation is assumed to be a graceful shutdown, which should
        # only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as failure:
        logger.error("Engine background task failed", exc_info=failure)
        error_callback(failure)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on GitHub. See stack trace above for the "
            "actual cause.") from failure
72
73


74
75
76
# Unique sentinel object; AsyncStream.finish() enqueues it to end a stream's
# iteration normally (rather than raising).
STOP_ITERATION: Exception = Exception()


Antoni Baum's avatar
Antoni Baum committed
77
class AsyncStream:
    """A stream of RequestOutputs or PoolingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator.

    Producers push items with `put()` and close the stream with `finish()`.
    `finish()` enqueues either an exception (instance or class), which the
    consumer raises, or the module-level `STOP_ITERATION` sentinel, which ends
    iteration normally.
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        # Invoked with request_id if the consumer closes the generator early.
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, PoolingRequestOutput,
                              Exception]) -> None:
        """Enqueue `item` for the consumer; silently ignored once finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Close the stream. Only the first call has any effect.

        Args:
            exception: If an exception instance or class, the consumer will
                raise it after draining earlier items; otherwise iteration
                ends normally.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        """Whether `finish()` has been called."""
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        """Yield queued outputs until the stream is finished.

        Raises:
            BaseException: Any exception passed to `finish()`.
            asyncio.CancelledError: If the consumer closes the generator
                early (after the cancel callback has been invoked).
        """
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    # STOP_ITERATION is a unique sentinel object, so compare
                    # by identity rather than equality.
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            # The consumer stopped iterating: cancel the underlying request.
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """Return True if `value` is an exception instance or class."""
        return isinstance(value, BaseException) or (
            isinstance(value, type) and issubclass(value, BaseException))

Antoni Baum's avatar
Antoni Baum committed
126

127
128
129
130
131
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    New requests and abort requests are queued here and handed to the
    background engine loop in batches by `get_new_and_aborted_requests()`.
    """

    def __init__(self) -> None:
        # request_id -> stream, for requests already handed to the engine.
        self._request_streams: Dict[str, AsyncStream] = {}
        # (stream, engine add_request kwargs) pairs awaiting hand-off.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Request ids to abort on the next loop iteration.
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # Set whenever a request is queued, to wake the engine loop.
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Abort one tracked request (or all of them when `request_id` is
        None), finishing each affected stream with `exc`."""
        if request_id is None:
            # Snapshot the ids first: abort_request pops entries from the
            # dict, so it cannot be iterated directly.
            targets = list(self._request_streams)
        else:
            targets = [request_id]
        for rid in targets:
            self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     PoolingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Route an engine output to its stream, closing it when finished."""
        request_id = request_output.request_id
        finished = request_output.finished

        # Finished requests stop being tracked. The lookups tolerate a
        # missing id, which happens if the request was aborted while this
        # output was being generated.
        streams = self._request_streams
        stream = (streams.pop(request_id, None)
                  if finished else streams.get(request_id))

        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if verbose and finished:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Finish a single request's stream with an engine-side exception."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Queue a request for the next engine loop iteration.

        Returns:
            The stream on which the request's outputs will be delivered.

        Raises:
            KeyError: If `request_id` is already being tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id,
                             partial(self.abort_request, verbose=verbose))
        payload = {"request_id": request_id, **engine_add_request_kwargs}
        self._new_requests.put_nowait((stream, payload))
        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Schedule `request_id` for abort on the next loop iteration and
        finish its stream (with `exception`, if given) when tracked."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)

        if (stream := self._request_streams.pop(request_id, None)) is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Drain the pending queues for the engine loop.

        Returns:
            A pair of (kwargs dicts for requests to add, ids of requests to
            abort).
        """
        aborted_queue = self._aborted_requests
        finished_requests: Set[str] = set()
        while not aborted_queue.empty():
            finished_requests.add(aborted_queue.get_nowait())

        pending_queue = self._new_requests
        new_requests: List[Dict] = []
        while not pending_queue.empty():
            stream, new_request = pending_queue.get_nowait()
            request_id = stream.request_id
            if request_id not in finished_requests:
                self._request_streams[request_id] = stream
                new_requests.append(new_request)
                continue
            # Aborted before being handed to the engine: cancel the stream
            # and drop the id from the returned abort set.
            stream.finish(asyncio.CancelledError)
            finished_requests.discard(request_id)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Wait until the new-requests event is set (unless requests are
        already queued), then clear the event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Whether any requests are queued but not yet handed off."""
        return not self._new_requests.empty()
259

Antoni Baum's avatar
Antoni Baum committed
260
261
262
263

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to `LLMEngine.__init__`."""
        super().__init__(*args, **kwargs)

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are ran asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.

        Args:
            virtual_engine: Index selecting the scheduler, scheduler context
                and cached scheduler outputs to use for this step.

        Returns:
            The scheduler context's `request_outputs` list for this step.
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # Bind up front so the name is always defined. It is only replaced
        # when a fresh, non-empty schedule is produced below.
        finished_requests_ids = []

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):

            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            if not scheduler_outputs.is_empty():
                # this will cause mamba_cache/minimax_cache failed
                # to release finished_requests_ids of the last steps
                finished_requests_ids = self.scheduler[
                    virtual_engine].get_and_reset_finished_requests_ids()

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # Clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            if outputs and allow_async_output_proc:
                assert len(
                    outputs
                ) == 1, "Async postprocessor expects only a single output set"
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)

        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def get_tokenizer_async(self,
                                  lora_request: Optional[LoRARequest] = None
                                  ) -> AnyTokenizer:
        """Return the tokenizer for `lora_request` from the tokenizer group."""
        return await (
            self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))

    async def add_request_async(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        """
        Async version of
        [`add_request`][vllm.engine.llm_engine.LLMEngine.add_request].

        Raises:
            ValueError: If `lora_request` is given but LoRA is not enabled,
                if a non-zero `priority` is given without priority
                scheduling, or if `data_parallel_rank` is set (V1 only).
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")
        if arrival_time is None:
            arrival_time = time.time()

        if data_parallel_rank is not None:
            raise ValueError("Targeting data_parallel_rank only supported "
                             "in v1 client.")

        if (isinstance(prompt, dict)
                and prompt.get("prompt_embeds", None) is not None
                and not prompt.get("prompt_token_ids", None)):
            # We use the -2 dimension (instead of 0) in case a batched input
            # of batch size 1 is passed in.
            # NOTE: this writes placeholder token ids into the caller's
            # prompt dict in place.
            prompt["prompt_token_ids"] = [0
                                          ] * prompt["prompt_embeds"].shape[-2]

        processed_inputs = await self.input_preprocessor.preprocess_async(
            prompt,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        if isinstance(params, SamplingParams) and \
            params.guided_decoding is not None:
            # Guided decoding has an async implementation for building logits
            # processors in a separate threadpool.
            # We want to invoke that here instead of using the blocking
            # implementation in the LLMEngine
            params = await build_guided_decoding_logits_processor_async(
                sampling_params=params,
                tokenizer=await self.get_tokenizer_async(lora_request),
                default_guided_backend=self.decoding_config.backend,
                reasoning_backend=self.decoding_config.reasoning_backend,
                model_config=self.model_config)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    async def check_health_async(self) -> None:
        """Run the model executor's (synchronous) health check."""
        self.model_executor.check_health()

    async def collective_rpc_async(self,
                                   method: str,
                                   timeout: Optional[float] = None,
                                   args: tuple = (),
                                   kwargs: Optional[dict] = None):
        """Not supported by this engine.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

505

506
507
async def build_guided_decoding_logits_processor_async(
        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
        default_guided_backend: str, reasoning_backend: Optional[str],
        model_config: ModelConfig) -> SamplingParams:
    """Turn `sampling_params.guided_decoding` into a logits processor.

    If `guided_decoding` is None, `sampling_params` is returned as-is.
    Otherwise a shallow copy of `sampling_params` is made. A logits processor
    is built from its guided decoding params, using `default_guided_backend`
    when no backend is set. The processor (if any) is appended to a new
    `logits_processors` list on the copy, `guided_decoding` is cleared on the
    copy, and the copy is returned.

    Args:
        sampling_params: Request sampling parameters.
        tokenizer: Tokenizer passed to the processor builder.
        default_guided_backend: Backend used when none is specified.
        reasoning_backend: Optional reasoning backend passed through.
        model_config: Model config passed to the processor builder.

    Returns:
        The (possibly copied and updated) sampling parameters.
    """
    if sampling_params.guided_decoding is None:
        return sampling_params

    # Defensively copy sampling params since guided decoding logits
    # processors can have different state for each request
    sampling_params = copy.copy(sampling_params)
    guided_decoding = sampling_params.guided_decoding

    logger.debug(
        "Building guided decoding logits processor. "
        "guided_decoding: %s%s", guided_decoding,
        f", reasoning_backend: {reasoning_backend}"
        if reasoning_backend is not None else "")

    # NOTE: copy.copy is shallow, so this guided decoding params object is
    # still the one referenced by the caller's SamplingParams.
    guided_decoding.backend = guided_decoding.backend or default_guided_backend

    processor = await get_guided_decoding_logits_processor(
        guided_params=guided_decoding,
        tokenizer=tokenizer,
        reasoning_backend=reasoning_backend,
        model_config=model_config)

    if processor:
        # Build a new list instead of appending in place: after the shallow
        # copy, `logits_processors` is still the caller's list, and appending
        # would leak this request's processor into every later request that
        # reuses the same SamplingParams.
        sampling_params.logits_processors = [
            *(sampling_params.logits_processors or []), processor
        ]

    # Unset guided decoding params after constructing the lp from them
    sampling_params.guided_decoding = None

    return sampling_params


548
class AsyncLLMEngine(EngineClient):
549
    """An asynchronous wrapper for [`LLMEngine`][vllm.LLMEngine].
550

551
552
553
554
555
556
    This class is used to wrap the [`LLMEngine`][vllm.LLMEngine] class to
    make it asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The [`LLMEngine`][vllm.LLMEngine] is kicked
    by the generate method when there are requests in the waiting queue. The
    generate method yields the outputs from the [`LLMEngine`][vllm.LLMEngine]
    to the caller.
557
558

    Args:
559
        log_requests: Whether to log the requests.
560
561
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
562
563
        *args: Arguments for [`LLMEngine`][vllm.LLMEngine].
        **kwargs: Arguments for [`LLMEngine`][vllm.LLMEngine].
564
    """
565

Antoni Baum's avatar
Antoni Baum committed
566
567
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

568
569
570
    def __init__(self,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Build the wrapped engine (`_engine_class`) from `*args/**kwargs`.

        Raises:
            ValueError: If `envs.VLLM_USE_V1` is enabled.
        """
        if envs.VLLM_USE_V1:
            raise ValueError(
                "Using V0 AsyncLLMEngine, but envs.VLLM_USE_V1=True. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        self.log_requests = log_requests
        engine = self._engine_class(*args, **kwargs)
        self.engine = engine

        # Delivering request outputs through a callback keeps the appends to
        # the asyncio queues prompt, especially for multi-step.
        use_callback = engine.model_config.use_async_output_proc
        self.use_process_request_outputs_callback = use_callback
        if use_callback:
            # weak_bind is used instead of a plain bound method — presumably
            # so the engine does not keep `self` alive; confirm in vllm.utils.
            engine.process_request_outputs_callback = weak_bind(
                self.process_request_outputs)

        self.background_loop: Optional[asyncio.Future] = None
        # Keep a reference to the unshielded task too, so that it is not
        # garbage collected.
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Assigned lazily by start_background_loop().
        self._request_tracker: RequestTracker

604
605
606
607
608
    def __del__(self):
        """Wake the engine loop on garbage collection so it can exit cleanly.

        The tracker lives in `_request_tracker`. That attribute only exists
        once start_background_loop() has run, hence the getattr default.
        """
        # The attribute name must match the one assigned in
        # start_background_loop(). Compare against None explicitly:
        # RequestTracker defines __len__, so an idle tracker (no active
        # streams) is falsy, and that is exactly when the loop is blocked
        # waiting on this event.
        tracker = getattr(self, "_request_tracker", None)
        if tracker is None:
            return
        try:
            tracker.new_requests_event.set()
        except RuntimeError:
            # Best effort during finalization: the event loop that owns the
            # waiters may already be closed.
            pass

609
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Choose the executor class for `engine_config`.

        Delegates to `LLMEngine._get_executor_cls`.
        """
        executor_cls = LLMEngine._get_executor_cls(engine_config)
        return executor_cls
613

614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
        disable_log_requests: bool = False,
        disable_log_stats: bool = False,
    ) -> "AsyncLLMEngine":
        """Create an AsyncLLMEngine from a VllmConfig.

        The `disable_*` flags are inverted into the engine's `log_requests`
        and `log_stats` keyword arguments.
        """
        engine_kwargs = dict(
            vllm_config=vllm_config,
            executor_class=cls._get_executor_cls(vllm_config),
            start_engine_loop=start_engine_loop,
            log_requests=not disable_log_requests,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )
        return cls(**engine_kwargs)

636
637
638
639
640
641
642
643
644
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        When `envs.VLLM_USE_V1` is set, the V1 `AsyncLLM` class is built
        instead of `cls`.
        """
        vllm_config = engine_args.create_engine_config(usage_context)

        if envs.VLLM_USE_V1:
            from vllm.v1.engine.async_llm import AsyncLLM as V1AsyncLLMEngine
            engine_cls = V1AsyncLLMEngine
        else:
            engine_cls = cls

        return engine_cls.from_vllm_config(
            vllm_config=vllm_config,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            disable_log_stats=engine_args.disable_log_stats,
            disable_log_requests=engine_args.disable_log_requests,
        )
661

662
663
    @property
    def is_running(self) -> bool:
        """True while the background loop exists and its task is not done."""
        task = self._background_loop_unshielded
        if self.background_loop is None or task is None:
            return False
        return not task.done()

    @property
    def is_stopped(self) -> bool:
        """True if the engine errored or its background loop task is done."""
        if self.errored:
            return True
        task = self._background_loop_unshielded
        return (self.background_loop is not None and task is not None
                and task.done())

    @property
    def errored(self) -> bool:
        """True once an exception has been recorded via set_errored()."""
        recorded = self._errored_with
        return recorded is not None

678
    @property
    def dead_error(self) -> BaseException:
        """Build (without raising) the error used when the loop is down."""
        message = ("Background loop is not running. If it was running, "
                   "inspect the output to find the stacktrace of the "
                   "error that caused the background loop to stop "
                   "(AsyncEngineDeadError).")
        return AsyncEngineDeadError(message)
685

686
687
688
689
690
691
    def set_errored(self, exc: Exception) -> None:
        """Record `exc` as the reason the engine is no longer usable."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine as errored, then fail all tracked requests."""
        self.set_errored(exc)
        tracker = self._request_tracker
        tracker.propagate_exception(exc)
692

693
694
695
    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Return the wrapped engine's input preprocessor."""
        preprocessor = self.engine.input_preprocessor
        return preprocessor

696
697
698
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer via the wrapped engine's async lookup."""
        tokenizer = await self.engine.get_tokenizer_async(lora_request)
        return tokenizer
701

702
    def start_background_loop(self) -> None:
        """Start the background loop.

        Creates a fresh ``RequestTracker`` and schedules ``run_engine_loop``
        as a task on the current event loop. The task receives only a weak
        reference to ``self`` so it does not keep the engine alive.

        Raises:
            AsyncEngineDeadError: If the engine has already errored (the
                recorded error is chained as the cause).
            RuntimeError: If the background loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        self._background_loop_unshielded = asyncio.get_event_loop(
        ).create_task(self.run_engine_loop(weakref.ref(self)))
        # When the task finishes, _log_task_completion is invoked with
        # _error_callback as its error hook (see that helper for details).
        self._background_loop_unshielded.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        # Expose a shielded handle: cancelling an awaiter of background_loop
        # does not cancel the underlying engine loop task.
        self.background_loop = asyncio.shield(self._background_loop_unshielded)

    def shutdown_background_loop(self) -> None:
        """
        Shut down the background loop.

        This method needs to be called during cleanup to remove
        references to `self` and properly GC the resources held
        by the async LLM engine (e.g., the executors as well as
        their resources).
        """
        if self._background_loop_unshielded is not None:
            # Cancel the real task, not the shield wrapping it.
            self._background_loop_unshielded.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Drains newly added and aborted requests from the request tracker,
        forwards them to the wrapped engine, runs one ``step_async`` on
        ``virtual_engine`` and routes the produced outputs to their streams.

        Args:
            virtual_engine: Index of the virtual engine to step.

        Returns:
            True if at least one output produced by this step is not yet
            finished; False otherwise (including when the step produced no
            outputs).
        """
        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())

        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            try:
                await self.engine.add_request_async(**new_request)
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                # Report the failure for this request only and keep adding
                # the remaining ones.
                self._request_tracker.process_exception(
                    new_request["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if aborted_requests:
            await self._engine_abort(aborted_requests)

        request_outputs = await self.engine.step_async(virtual_engine)

        # Put the outputs into the corresponding streams.
        # If used as a callback, then already invoked inside
        # LLMEngine's _process_model_outputs
        if not self.use_process_request_outputs_callback:
            all_finished = self.process_request_outputs(request_outputs)
        else:
            # For callback case, we only need to detect when all
            # requests are finished
            all_finished = all(request_output.finished
                               for request_output in request_outputs)

        return not all_finished

    def process_request_outputs(self, request_outputs) -> bool:
        """Route each output to its stream via the request tracker.

        Every output is forwarded; the loop never stops early, even after an
        unfinished output has been seen.

        Args:
            request_outputs: Iterable of outputs, each exposing ``finished``.

        Returns:
            True if every output is finished (also True for an empty input).
        """
        # Put the outputs into the corresponding streams.
        all_finished = True
        for request_output in request_outputs:
            self._request_tracker.process_request_output(
                request_output, verbose=self.log_requests)
            all_finished = all_finished and request_output.finished

        return all_finished

    async def _engine_abort(self, request_ids: Iterable[str]):
        """Abort ``request_ids`` in the wrapped engine.

        ``abort_request`` is called directly; nothing is awaited.
        """
        self.engine.abort_request(request_ids)

    @staticmethod
    async def run_engine_loop(engine_ref: ReferenceType):
        """We use a weakref to the engine so that the running loop
        doesn't prevent the engine being garbage collected.

        Keeps one ``engine_step`` task per pipeline-parallel virtual engine
        in flight. When no virtual engine has work, the loop parks on the
        request tracker until new requests arrive. It returns as soon as the
        weakly referenced engine has been garbage collected.

        Raises:
            asyncio.TimeoutError: If one iteration exceeds
                ``ENGINE_ITERATION_TIMEOUT_S``; the engine is marked as
                errored before re-raising.
            Exception: Any exception raised by an ``engine_step`` task is
                re-raised here via ``task.result()``.
        """
        engine: Optional[AsyncLLMEngine] = engine_ref()
        if not engine:
            return

        pipeline_parallel_size = \
                engine.engine.parallel_config.pipeline_parallel_size
        # One flag per virtual engine: does it still have work in flight?
        has_requests_in_progress = [False] * pipeline_parallel_size
        while True:
            if not any(has_requests_in_progress):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                await engine.engine.stop_remote_worker_execution_loop_async()
                request_tracker = engine._request_tracker
                # Allow engine to be garbage collected while
                # waiting for new requests
                del engine
                await asyncio.sleep(0)
                if engine_ref() is None:
                    return
                await request_tracker.wait_for_new_requests()
                engine = engine_ref()
                if not engine:
                    return
                logger.debug("Got new requests!")
                requests_in_progress = [
                    asyncio.create_task(engine.engine_step(ve))
                    for ve in range(pipeline_parallel_size)
                ]
                has_requests_in_progress = [True] * pipeline_parallel_size

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        requests_in_progress,
                        return_when=asyncio.FIRST_COMPLETED)
                    for _ in range(pipeline_parallel_size):
                        await asyncio.sleep(0)
                for task in done:
                    result = task.result()
                    virtual_engine = requests_in_progress.index(task)
                    has_unfinished_requests = (
                        engine.engine.
                        has_unfinished_requests_for_virtual_engine(
                            virtual_engine))
                    # Re-arm this virtual engine if its last step reported
                    # in-progress outputs or the engine still holds work.
                    if result or has_unfinished_requests:
                        requests_in_progress[virtual_engine] = (
                            asyncio.create_task(
                                engine.engine_step(virtual_engine)))
                        has_requests_in_progress[virtual_engine] = True
                    else:
                        has_requests_in_progress[virtual_engine] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                engine.set_errored(exc)
                raise
            await asyncio.sleep(0)

    async def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        """Register a request with the tracker and return its output stream.

        Starts the background loop on demand when ``start_engine_loop`` is
        set. The request itself is handed to the wrapped engine later, by
        ``engine_step``.

        Args:
            request_id: The unique id of the request.
            prompt: The prompt to the LLM.
            params: Sampling or pooling parameters of the request.
            arrival_time: Arrival timestamp; a falsy value (``None`` or 0)
                is replaced by ``time.time()``.
            lora_request: LoRA request to use, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request. Only applicable with
                priority scheduling.
            data_parallel_rank: The (global) data parallel rank that must
                handle this request.
            tokenization_kwargs: Extra keyword arguments forwarded with the
                request to the tracker.

        Returns:
            An async generator yielding the request's outputs.

        Raises:
            AsyncEngineDeadError: If the loop is not running and
                ``start_engine_loop`` is False, or if starting the loop fails
                because the engine has already errored.
            ValueError: If ``priority`` is non-zero but the scheduler policy
                is not ``"priority"``.
        """
        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise self.dead_error

        if (priority != 0
                and self.engine.scheduler_config.policy != "priority"):
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            prompt=prompt,
            params=params,
            arrival_time=arrival_time or time.time(),
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=priority,
            data_parallel_rank=data_parallel_rank,
            tokenization_kwargs=tokenization_kwargs,
        )

        return stream.generator()

    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See
                [`PromptType`][vllm.inputs.PromptType] for more details about
                the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.
            data_parallel_rank: The (global) data parallel rank that must
                handle this request. Only applicable if DP is enabled.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Raises:
            asyncio.CancelledError: Re-raised after the request has been
                aborted when the consuming task is cancelled.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              [`engine_step`][vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step]
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> # note that engine_args here is AsyncEngineArgs instance
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the client connection (see api_server.py)
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        try:
            async for output in await self.add_request(
                    request_id,
                    prompt,
                    sampling_params,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=priority,
                    data_parallel_rank=data_parallel_rank,
            ):
                yield LLMEngine.validate_output(output, RequestOutput)
        except asyncio.CancelledError:
            # The consumer went away: abort the request before propagating.
            await self.abort(request_id)
            raise

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See
                [`PromptType`][vllm.inputs.PromptType] for more details about
                the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.
            tokenization_kwargs: Extra keyword arguments forwarded with the
                request to `add_request`.

        Yields:
            The output `PoolingRequestOutput` objects from the LLMEngine
            for the request.

        Raises:
            asyncio.CancelledError: Re-raised after the request has been
                aborted when the consuming task is cancelled.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              [`vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`][]
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
        ```
        # Please refer to entrypoints/api_server.py for
        # the complete example.

        # initialize the engine and the example input
        # note that engine_args here is AsyncEngineArgs instance
        engine = AsyncLLMEngine.from_engine_args(engine_args)
        example_input = {
            "input": "What is LLM?",
            "request_id": "0",
        }

        # start the generation
        results_generator = engine.encode(
            example_input["input"],
            PoolingParams(),
            example_input["request_id"])

        # get the results
        final_output = None
        async for request_output in results_generator:
            # `request` is the client connection (see api_server.py)
            if await request.is_disconnected():
                # Abort the request if the client disconnects.
                await engine.abort(example_input["request_id"])
                # Return or raise an error
                ...
            final_output = request_output

        # Process and return the final output
        ...
        ```
        """
        try:
            async for output in await self.add_request(
                    request_id,
                    prompt,
                    pooling_params,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=priority,
                    tokenization_kwargs=tokenization_kwargs,
            ):
                yield LLMEngine.validate_output(output, PoolingRequestOutput)
        except asyncio.CancelledError:
            # The consumer went away: abort the request before propagating.
            await self.abort(request_id)
            raise

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running
                (the same error object ``dead_error`` builds).
        """
        if not self.is_running:
            raise self.dead_error

        return self._abort(request_id)

    def _abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Unlike ``abort`` this does not check that the background loop is
        running; it hands the abort to the request tracker with
        ``asyncio.CancelledError`` as the exception type.

        Args:
            request_id: The unique id of the request.
        """
        self._request_tracker.abort_request(request_id,
                                            exception=asyncio.CancelledError,
                                            verbose=self.log_requests)

    async def get_vllm_config(self) -> VllmConfig:
        """Get the vllm configuration of the vLLM engine."""
        return self.engine.get_vllm_config()

    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        return self.engine.get_model_config()

    async def get_parallel_config(self) -> ParallelConfig:
        """Get the parallel configuration of the vLLM engine."""
        return self.engine.get_parallel_config()

    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        return self.engine.get_decoding_config()

    async def get_scheduler_config(self) -> SchedulerConfig:
        """Get the scheduling configuration of the vLLM engine."""
        return self.engine.get_scheduler_config()

    async def get_lora_config(self) -> LoRAConfig:
        """Get the lora configuration of the vLLM engine."""
        return self.engine.get_lora_config()

    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the wrapped engine to log its stats.

        ``scheduler_outputs`` and ``model_output`` are accepted but not used
        here; the wrapped engine's ``do_log_stats`` is called without
        arguments.
        """
        self.engine.do_log_stats()

    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy.

        Raises:
            AsyncEngineDeadError: If the engine is stopped (errored, or its
                background loop has exited).
            Exception: Whatever the wrapped engine's ``check_health_async``
                raises is propagated unchanged.
        """
        t = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        await self.engine.check_health_async()
        logger.debug("Health check took %fs", time.perf_counter() - t)

    async def is_tracing_enabled(self) -> bool:
        """Report whether the wrapped engine has tracing enabled."""
        return self.engine.is_tracing_enabled()

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger under ``logger_name`` on the wrapped engine."""
        self.engine.add_logger(logger_name=logger_name, logger=logger)

    def remove_logger(self, logger_name: str) -> None:
        """Remove the stat logger registered as ``logger_name``."""
        self.engine.remove_logger(logger_name=logger_name)

    async def start_profile(self) -> None:
        """Ask the wrapped engine to start profiling."""
        self.engine.start_profile()

    async def stop_profile(self) -> None:
        """Ask the wrapped engine to stop profiling."""
        self.engine.stop_profile()

    async def reset_mm_cache(self) -> None:
        """Ask the wrapped engine to reset its multi-modal cache."""
        self.engine.reset_mm_cache()

    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        """Reset the wrapped engine's prefix cache, optionally for ``device``."""
        self.engine.reset_prefix_cache(device)

    async def sleep(self, level: int = 1) -> None:
        """Forward a sleep request with the given ``level`` to the engine."""
        self.engine.sleep(level)

    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Forward a wake-up request (optionally restricted to ``tags``)."""
        self.engine.wake_up(tags)

    async def is_sleeping(self) -> bool:
        """Report whether the wrapped engine is currently sleeping."""
        return self.engine.is_sleeping()

    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Load the LoRA adapter described by ``lora_request`` into the engine."""
        self.engine.add_lora(lora_request)

    async def collective_rpc(self,
                             method: str,
                             timeout: Optional[float] = None,
                             args: tuple = (),
                             kwargs: Optional[dict] = None):
        """
        Perform a collective RPC call to the given path.

        Arguments are passed positionally to the wrapped engine's
        ``collective_rpc_async`` and its awaited result is returned.
        """
        return await self.engine.collective_rpc_async(method, timeout, args,
                                                      kwargs)


# TODO(v1): Remove this class proxy when V1 goes default.
# When VLLM_USE_V1 is explicitly set and truthy, rebind the module-level name
# `AsyncLLMEngine` to the V1 implementation, so code importing it from this
# module gets `AsyncLLM` instead of the class defined above.
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
    from vllm.v1.engine.async_llm import AsyncLLM

    AsyncLLMEngine = AsyncLLM  # type: ignore