async_llm_engine.py 45.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import asyncio
import time
6
import weakref
Antoni Baum's avatar
Antoni Baum committed
7
from functools import partial
8
9
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                    Mapping, Optional, Set, Tuple, Type, Union)
10
from weakref import ReferenceType
11

12
import vllm.envs as envs
13
14
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, VllmConfig)
15
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
16
from vllm.engine.arg_utils import AsyncEngineArgs
17
from vllm.engine.async_timeout import asyncio_timeout
18
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
19
from vllm.engine.metrics_types import StatLoggerBase
20
from vllm.engine.protocol import EngineClient
21
from vllm.executor.executor_base import ExecutorBase
22
from vllm.inputs import PromptType
23
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
24
from vllm.logger import init_logger
25
from vllm.lora.request import LoRARequest
26
from vllm.model_executor.layers.sampler import SamplerOutput
27
from vllm.outputs import PoolingRequestOutput, RequestOutput
28
from vllm.pooling_params import PoolingParams
29
from vllm.sampling_params import SamplingParams
30
from vllm.sequence import ExecuteModelRequest
31
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
32
from vllm.usage.usage_lib import UsageContext
33
from vllm.utils import Device, deprecate_kwargs, weak_bind
34
35

# Logger shared by every component defined in this module.
logger = init_logger(__name__)
# Value of VLLM_ENGINE_ITERATION_TIMEOUT_S, captured once at import time.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
37

Antoni Baum's avatar
Antoni Baum committed
38

39
40
41
42
class AsyncEngineDeadError(RuntimeError):
    """Raised when the async engine's background loop is no longer usable."""


43
44
45
46
47
48
49
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback for the `engine.run_engine_loop()` task only.

    That task runs a `while True` loop, so it is only ever expected to end by
    being cancelled (graceful shutdown) or by raising.

    Args:
        task: The finished background-loop task.
        error_callback: Called with the failure when the task did not end by
            cancellation.

    Raises:
        AsyncEngineDeadError: If the task raised, or returned normally.
    """
    try:
        unexpected_result = task.result()
        # Getting here means the loop returned instead of raising, which is
        # itself a failure. This AssertionError is intentionally handled by
        # the `except Exception` branch below.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {unexpected_result}")
    except asyncio.exceptions.CancelledError:
        # Cancellation is treated as a graceful shutdown, which should only
        # happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as err:
        logger.error("Engine background task failed", exc_info=err)
        error_callback(err)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on GitHub. See stack trace above for the "
            "actual cause.") from err
69
70


71
72
73
# Sentinel placed on an AsyncStream's queue by `finish()` to mark a normal
# end of stream; it must be a unique instance so it can be told apart from
# real exceptions.
STOP_ITERATION = Exception()


Antoni Baum's avatar
Antoni Baum committed
74
class AsyncStream:
    """A stream of RequestOutputs or PoolingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator.

    Producers push items with `put()` and close the stream with `finish()`.
    The consumer iterates `generator()`, which yields items until the
    end-of-stream sentinel arrives (normal end) or an exception item arrives
    (which is raised to the consumer).
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        # Called with `request_id` if the consumer abandons the generator.
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, PoolingRequestOutput,
                              Exception]) -> None:
        """Enqueue `item` for the consumer; ignored once finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Close the stream (idempotent).

        Args:
            exception: If an exception instance or class, the consumer's
                generator raises it; otherwise the generator ends normally.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        """Whether `finish()` has been called."""
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        """Yield queued outputs until the stream is finished.

        Raises:
            BaseException: Whatever exception was passed to `finish()` or
                `put()`.
            asyncio.CancelledError: If the consumer closes the generator
                early; the request is cancelled via the `cancel` callback
                first.
        """
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    # Identity check: the sentinel is a unique object, and an
                    # arbitrary exception's `__eq__` must not be consulted.
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """True for exception instances and exception classes."""
        return isinstance(value, BaseException) or (
            isinstance(value, type) and issubclass(value, BaseException))

Antoni Baum's avatar
Antoni Baum committed
123

124
125
126
127
128
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    New requests are queued by `add_request` and become tracked (visible to
    `in` and `len()`) only once `get_new_and_aborted_requests` hands them to
    the engine loop. Aborts are queued the same way so the loop can forward
    them to the engine.
    """

    def __init__(self) -> None:
        # request_id -> stream, for requests already handed to the engine.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Ids aborted since the last `get_new_and_aborted_requests` call.
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, engine add_request kwargs) pairs waiting to be handed off.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Ids currently waiting in `_new_requests`, so `add_request` can
        # reject a duplicate id before the loop has picked up the first one.
        self._pending_request_ids: Set[str] = set()
        # Set whenever a request is queued; awaited by `wait_for_new_requests`.
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None)."""
        if request_id is not None:
            self.abort_request(request_id, exception=exc)
        else:
            # NB: tuple() used here because self.abort_request pops the stream
            # out of self._request_streams, so we can't iterate on it directly
            # NOTE(review): streams still waiting in `_new_requests` are not
            # reached here.
            for rid in tuple(self._request_streams.keys()):
                self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     PoolingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        The output is pushed to its stream; a finished output also finishes
        the stream and stops tracking the request.
        """
        request_id = request_output.request_id
        finished = request_output.finished

        if finished:
            stream = self._request_streams.pop(request_id, None)
        else:
            stream = self._request_streams.get(request_id)
        # Guard against a KeyError which can occur if the request was aborted
        # while the output was generated
        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if verbose and finished:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Returns:
            The stream the caller iterates to receive outputs.

        Raises:
            KeyError: If `request_id` is already tracked, or is already queued
                and not yet picked up by the engine loop.
        """
        if (request_id in self._request_streams
                or request_id in self._pending_request_ids):
            raise KeyError(f"Request {request_id} already exists.")

        abort_request = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, abort_request)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))
        self._pending_request_ids.add(request_id)

        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Abort a request during next background loop iteration.

        A tracked stream is finished immediately (with `exception`, if any).
        A still-queued stream is cancelled by `get_new_and_aborted_requests`.
        """
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)
        # A queued copy of this request will be cancelled on the next
        # hand-off, so release the id and allow it to be re-submitted now.
        self._pending_request_ids.discard(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        Returns:
            A list of kwargs dicts for the engine's add_request, and the set
            of request ids the engine must abort. Requests that were aborted
            before ever reaching the engine are cancelled here and omitted
            from both.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._aborted_requests.empty():
            request_id = self._aborted_requests.get_nowait()
            finished_requests.add(request_id)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            request_id = stream.request_id
            self._pending_request_ids.discard(request_id)
            if request_id in finished_requests:
                # The request has already been aborted.
                stream.finish(asyncio.CancelledError)
                finished_requests.discard(request_id)
            else:
                self._request_streams[request_id] = stream
                new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Wait until at least one new request is queued, then reset the
        event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """True if requests are queued and not yet handed to the engine."""
        return not self._new_requests.empty()
256

Antoni Baum's avatar
Antoni Baum committed
257
258
259
260

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.

        Args:
            virtual_engine: Index selecting which scheduler, cached scheduler
                outputs and scheduler context this step operates on.

        Returns:
            The request outputs collected in this virtual engine's scheduler
            context during the step.
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):

            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            if not scheduler_outputs.is_empty():
                # this will cause mamba_cache/minimax_cache failed
                # to release finished_requests_ids of the last steps
                finished_requests_ids = self.scheduler[
                    virtual_engine].get_and_reset_finished_requests_ids()
            # When the schedule is empty `finished_requests_ids` stays unbound;
            # it is only read below under the same non-empty condition.

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)
        else:
            finished_requests_ids = list()

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            # Nothing scheduled: flush any pending async outputs instead.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # Clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            if outputs and allow_async_output_proc:
                assert len(outputs) == 1, (
                    "Async postprocessor expects only a single output set")
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)

        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def get_tokenizer_async(self,
                                  lora_request: Optional[LoRARequest] = None
                                  ) -> AnyTokenizer:
        """Return the tokenizer for `lora_request` from the tokenizer
        group."""
        return await (
            self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))

    async def add_request_async(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        """
        Async version of
        [`add_request`][vllm.engine.llm_engine.LLMEngine.add_request].

        Raises:
            ValueError: If `lora_request` is given while LoRA is disabled,
                if a non-zero `priority` is given without the "priority"
                scheduling policy, or if `data_parallel_rank` is set.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")
        if arrival_time is None:
            arrival_time = time.time()

        if data_parallel_rank is not None:
            raise ValueError("Targeting data_parallel_rank only supported "
                             "in v1 client.")

        if (isinstance(prompt, dict)
                and prompt.get("prompt_embeds", None) is not None
                and not prompt.get("prompt_token_ids", None)):
            # Fill placeholder token ids in a shallow copy rather than writing
            # into the caller's dict, which may be reused for other requests.
            # We use the -2 dimension (instead of 0) in case a batched input
            # of batch size 1 is passed in.
            num_embeds = prompt["prompt_embeds"].shape[-2]
            prompt = {**prompt, "prompt_token_ids": [0] * num_embeds}

        processed_inputs = await self.input_preprocessor.preprocess_async(
            prompt,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    async def check_health_async(self) -> None:
        """Delegate to the executor's synchronous `check_health()`."""
        self.model_executor.check_health()

    async def collective_rpc_async(self,
                                   method: str,
                                   timeout: Optional[float] = None,
                                   args: tuple = (),
                                   kwargs: Optional[dict] = None):
        """Not supported by this engine; always raises NotImplementedError."""
        raise NotImplementedError

489

490
class AsyncLLMEngine(EngineClient):
491
    """An asynchronous wrapper for [`LLMEngine`][vllm.LLMEngine].
492

493
494
495
496
497
498
    This class is used to wrap the [`LLMEngine`][vllm.LLMEngine] class to
    make it asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The [`LLMEngine`][vllm.LLMEngine] is kicked
    by the generate method when there are requests in the waiting queue. The
    generate method yields the outputs from the [`LLMEngine`][vllm.LLMEngine]
    to the caller.
499
500

    Args:
501
        log_requests: Whether to log the requests.
502
503
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
504
505
        *args: Arguments for [`LLMEngine`][vllm.LLMEngine].
        **kwargs: Arguments for [`LLMEngine`][vllm.LLMEngine].
506
    """
507

Antoni Baum's avatar
Antoni Baum committed
508
509
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

510
511
512
    def __init__(self,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Wrap a freshly constructed engine; the loop is started later.

        Raises:
            ValueError: If the V1 engine is enabled via `envs.VLLM_USE_V1`.
        """
        if envs.VLLM_USE_V1:
            raise ValueError(
                "Using V0 AsyncLLMEngine, but envs.VLLM_USE_V1=True. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        self.log_requests = log_requests
        self.engine = self._engine_class(*args, **kwargs)

        # With async output processing, have the engine push outputs to the
        # asyncio streams as soon as they are produced, so queue appends are
        # not delayed (especially for multi-step).
        use_callback = self.engine.model_config.use_async_output_proc
        self.use_process_request_outputs_callback = use_callback
        if use_callback:
            # presumably weak_bind avoids a strong engine -> self reference;
            # confirm in vllm.utils.
            self.engine.process_request_outputs_callback = weak_bind(
                self.process_request_outputs)

        # Background-loop bookkeeping, filled in by start_background_loop.
        self.background_loop: Optional[asyncio.Future] = None
        # The unshielded task is kept too so it is not garbage collected.
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Created lazily in start_background_loop.
        self._request_tracker: RequestTracker

546
547
548
549
550
    def __del__(self):
        """Wake up the engine loop so that it will exit cleanly.

        The attribute is `_request_tracker` (set in start_background_loop);
        it may be absent if the loop was never started. The None check is
        explicit because RequestTracker defines `__len__`, so a tracker with
        no live requests is falsy and a truthiness test would skip the
        wake-up exactly when the loop is idle.
        """
        rt = getattr(self, "_request_tracker", None)
        if rt is not None:
            rt.new_requests_event.set()

551
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Pick the executor class using LLMEngine's selection logic."""
        executor_cls = LLMEngine._get_executor_cls(engine_config)
        return executor_cls
555

556
    @classmethod
    @deprecate_kwargs(
        "disable_log_requests",
        additional_message=("This argument will have no effect. "
                            "Use `enable_log_requests` instead."),
    )
    def from_vllm_config(
            cls,
            vllm_config: VllmConfig,
            start_engine_loop: bool = True,
            usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
            stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
            enable_log_requests: bool = False,
            disable_log_stats: bool = False,
            disable_log_requests: bool = True,  # Deprecated, will be removed
    ) -> "AsyncLLMEngine":
        """Create an AsyncLLMEngine from a VllmConfig.

        `disable_log_requests` is accepted only for backward compatibility
        and is not used; request logging follows `enable_log_requests`.
        """
        executor_class = cls._get_executor_cls(vllm_config)
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            start_engine_loop=start_engine_loop,
            log_requests=enable_log_requests,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

584
585
586
587
588
589
590
591
592
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        When `envs.VLLM_USE_V1` is set, the V1 `AsyncLLM` class is built
        instead of `cls`.
        """
        vllm_config = engine_args.create_engine_config(usage_context)

        if envs.VLLM_USE_V1:
            from vllm.v1.engine.async_llm import AsyncLLM as V1AsyncLLMEngine
            engine_cls = V1AsyncLLMEngine
        else:
            engine_cls = cls

        return engine_cls.from_vllm_config(
            vllm_config=vllm_config,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            disable_log_stats=engine_args.disable_log_stats,
            enable_log_requests=engine_args.enable_log_requests,
        )
609

610
611
    @property
    def is_running(self) -> bool:
        """True while the background loop task exists and is not done."""
        if self.background_loop is None:
            return False
        task = self._background_loop_unshielded
        return task is not None and not task.done()

    @property
    def is_stopped(self) -> bool:
        """True if the engine errored or its background loop task is done."""
        if self.errored:
            return True
        if self.background_loop is None:
            return False
        task = self._background_loop_unshielded
        return task is not None and task.done()

    @property
    def errored(self) -> bool:
        """True once an error has been recorded via `set_errored`."""
        recorded = self._errored_with
        return recorded is not None

626
    @property
    def dead_error(self) -> BaseException:
        """A fresh AsyncEngineDeadError describing a stopped loop."""
        message = ("Background loop is not running. If it was running, "
                   "inspect the output to find the stacktrace of the "
                   "error that caused the background loop to stop "
                   "(AsyncEngineDeadError).")
        return AsyncEngineDeadError(message)
633

634
635
636
637
638
639
    def set_errored(self, exc: Exception) -> None:
        """Record `exc` as the reason the engine is unusable."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine errored and fail every tracked request with `exc`."""
        self.set_errored(exc)
        tracker = self._request_tracker
        tracker.propagate_exception(exc)
640

641
642
643
    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Return the wrapped engine's input preprocessor."""
        preprocessor = self.engine.input_preprocessor
        return preprocessor

644
645
646
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer (LoRA-specific if `lora_request` is given)."""
        tokenizer = await self.engine.get_tokenizer_async(lora_request)
        return tokenizer
649

650
    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # Build the tracker here so its asyncio primitives belong to the
        # event loop that will run the engine.
        self._request_tracker = RequestTracker()

        event_loop = asyncio.get_event_loop()
        # The loop coroutine only receives a weak reference to self.
        task = event_loop.create_task(self.run_engine_loop(weakref.ref(self)))
        self._background_loop_unshielded = task
        on_done = partial(_log_task_completion,
                          error_callback=self._error_callback)
        task.add_done_callback(on_done)
        self.background_loop = asyncio.shield(task)
Antoni Baum's avatar
Antoni Baum committed
665

666
667
668
669
670
671
672
673
674
675
676
677
678
679
    def shutdown_background_loop(self) -> None:
        """Cancel the background loop and drop references to it.

        Call this during cleanup: clearing these references lets `self`, and
        the resources held by the async LLM engine (such as the executors),
        be garbage collected.
        """
        task = self._background_loop_unshielded
        if task is not None:
            task.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

680
    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Queued requests are added to the engine (validation errors are
        routed to the request's stream), aborts are forwarded, and one
        `step_async` is run.

        Returns:
            True if there are in-progress requests.
        """
        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())

        for request_kwargs in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            try:
                await self.engine.add_request_async(**request_kwargs)
            except ValueError as exc:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    request_kwargs["request_id"],
                    exc,
                    verbose=self.log_requests,
                )

        if aborted_requests:
            await self._engine_abort(aborted_requests)

        request_outputs = await self.engine.step_async(virtual_engine)

        if self.use_process_request_outputs_callback:
            # Outputs were already pushed to the streams by the callback
            # invoked inside LLMEngine's _process_model_outputs; only work
            # out whether everything has finished.
            all_finished = all(out.finished for out in request_outputs)
        else:
            # Put the outputs into the corresponding streams.
            all_finished = self.process_request_outputs(request_outputs)

        return not all_finished

    def process_request_outputs(self, request_outputs) -> bool:
        """Route every request output to its stream via the request tracker.

        Args:
            request_outputs: A batch of request outputs. Each must expose a
                ``finished`` attribute.

        Returns:
            Whether every output in the batch is finished. An empty batch
            counts as finished.
        """
        tracker = self._request_tracker
        verbose = self.log_requests
        all_finished = True
        for output in request_outputs:
            # Every output is dispatched, even after an unfinished one has
            # been seen.
            tracker.process_request_output(output, verbose=verbose)
            all_finished = all_finished and output.finished
        return all_finished
727

Antoni Baum's avatar
Antoni Baum committed
728
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Abort ``request_ids`` in the underlying engine.

        Declared ``async`` so it can be awaited from ``engine_step``. The
        abort itself is a plain synchronous ``abort_request`` call.
        """
        engine = self.engine
        engine.abort_request(request_ids)
Antoni Baum's avatar
Antoni Baum committed
730

731
732
733
734
    @staticmethod
    async def run_engine_loop(engine_ref: ReferenceType):
        """Keep every virtual engine stepping while requests are in flight.

        Only a weak reference to the engine is held while idle, so a running
        loop never keeps the engine alive. The coroutine returns once the
        referent has been garbage collected. A step that exceeds
        ``ENGINE_ITERATION_TIMEOUT_S`` marks the engine errored, and the
        timeout is re-raised.
        """
        llm_engine: Optional[AsyncLLMEngine] = engine_ref()
        if not llm_engine:
            return

        pp_size = llm_engine.engine.parallel_config.pipeline_parallel_size
        # busy[ve] is True while virtual engine ``ve`` has a step task in
        # flight in step_tasks[ve].
        busy = [False] * pp_size

        while True:
            if not any(busy):
                logger.debug("Waiting for new requests...")
                # Park the workers' execute-model loop until more requests
                # arrive. Otherwise they may wait indefinitely in
                # torch.distributed ops (which can time out), and their RPC
                # thread could not serve queued control-plane messages such
                # as add/remove LoRA adapters.
                await (llm_engine.engine.
                       stop_remote_worker_execution_loop_async())
                tracker = llm_engine._request_tracker
                # Drop the only strong reference before the (possibly long)
                # wait so the engine can be garbage collected meanwhile.
                del llm_engine
                await asyncio.sleep(0)
                if engine_ref() is None:
                    return
                await tracker.wait_for_new_requests()
                llm_engine = engine_ref()
                if not llm_engine:
                    return
                logger.debug("Got new requests!")
                step_tasks = []
                for ve in range(pp_size):
                    step_tasks.append(
                        asyncio.create_task(llm_engine.engine_step(ve)))
                busy = [True] * pp_size

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        step_tasks, return_when=asyncio.FIRST_COMPLETED)
                    for _ in range(pp_size):
                        await asyncio.sleep(0)
                for finished_task in done:
                    # .result() re-raises any exception from the step.
                    step_result = finished_task.result()
                    ve = step_tasks.index(finished_task)
                    still_pending = (
                        llm_engine.engine.
                        has_unfinished_requests_for_virtual_engine(ve))
                    keep_going = bool(step_result or still_pending)
                    if keep_going:
                        step_tasks[ve] = asyncio.create_task(
                            llm_engine.engine_step(ve))
                    busy[ve] = keep_going
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                llm_engine.set_errored(exc)
                raise

            await asyncio.sleep(0)

800
    async def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
        """Register a request with the request tracker and return its stream.

        The request is not handed to the engine here. The background loop
        picks it up from the request tracker in its next ``engine_step``.
        If the loop is not running and ``start_engine_loop`` is set, the
        loop is started first.

        Args:
            request_id: The unique id of the request.
            prompt: The prompt to process.
            params: Sampling or pooling parameters of the request.
            arrival_time: Arrival timestamp of the request. ``None`` means
                "now" (``time.time()``). Any explicit value, including
                ``0.0``, is kept as given.
            lora_request: LoRA request to use, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: Priority of the request. Must be 0 unless the
                scheduler policy is ``"priority"``.
            data_parallel_rank: The (global) data parallel rank that must
                handle this request, if any.
            tokenization_kwargs: Optional tokenization keyword arguments,
                forwarded unchanged to the request tracker.

        Returns:
            An async generator yielding the outputs of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                ``start_engine_loop`` is disabled.
            ValueError: If a non-zero ``priority`` is given while priority
                scheduling is not enabled.
        """
        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        if (priority != 0
                and self.engine.scheduler_config.policy != "priority"):
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        # Explicit None check: `arrival_time or time.time()` would silently
        # replace a valid falsy timestamp such as 0.0.
        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            prompt=prompt,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=priority,
            data_parallel_rank=data_parallel_rank,
            tokenization_kwargs=tokenization_kwargs,
        )

        return stream.generator()
841

842
    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is an async generator. It
        adds the request into the waiting queue of the LLMEngine and streams
        the outputs from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See
                [`PromptType`][vllm.inputs.PromptType] for more details about
                the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.
            data_parallel_rank: The (global) data parallel rank that must
                handle this request. Only applicable if DP is enabled.
            tokenization_kwargs: Optional tokenization keyword arguments,
                forwarded unchanged to `add_request`.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                the engine is not allowed to start it.
            ValueError: If a non-zero priority is given while priority
                scheduling is not enabled.
            asyncio.CancelledError: Re-raised after the request is aborted
                when the consumer cancels the generation.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              [`engine_step`][vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step]
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> # note that engine_args here is AsyncEngineArgs instance
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results; `request` is assumed to be the incoming
            >>> # HTTP request object of the serving framework
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        try:
            async for output in await self.add_request(
                    request_id,
                    prompt,
                    sampling_params,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=priority,
                    data_parallel_rank=data_parallel_rank,
                    tokenization_kwargs=tokenization_kwargs,
            ):
                yield LLMEngine.validate_output(output, RequestOutput)
        except asyncio.CancelledError:
            # The consumer went away: free the engine-side request, then
            # propagate the cancellation.
            await self.abort(request_id)
            raise
932
933
934

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        data_parallel_rank: Optional[int] = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model.

        Generate outputs for a request. This method is an async generator. It
        adds the request into the waiting queue of the LLMEngine and streams
        the outputs from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See
                [`PromptType`][vllm.inputs.PromptType] for more details about
                the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.
            tokenization_kwargs: Optional tokenization keyword arguments,
                forwarded unchanged to `add_request`.
            data_parallel_rank: The (global) data parallel rank that must
                handle this request. Only applicable if DP is enabled.

        Yields:
            The output `PoolingRequestOutput` objects from the LLMEngine
            for the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                the engine is not allowed to start it.
            ValueError: If a non-zero priority is given while priority
                scheduling is not enabled.
            asyncio.CancelledError: Re-raised after the request is aborted
                when the consumer cancels the generation.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              [`vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`][]
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
        ```
        # Please refer to entrypoints/api_server.py for
        # the complete example.

        # initialize the engine and the example input
        # note that engine_args here is AsyncEngineArgs instance
        engine = AsyncLLMEngine.from_engine_args(engine_args)
        example_input = {
            "input": "What is LLM?",
            "request_id": "0",
        }

        # start the generation
        results_generator = engine.encode(
            example_input["input"],
            PoolingParams(),
            example_input["request_id"])

        # get the results; `request` is assumed to be the incoming
        # HTTP request object of the serving framework
        final_output = None
        async for request_output in results_generator:
            if await request.is_disconnected():
                # Abort the request if the client disconnects.
                await engine.abort(example_input["request_id"])
                # Return or raise an error
                ...
            final_output = request_output

        # Process and return the final output
        ...
        ```
        """
        try:
            async for output in await self.add_request(
                    request_id,
                    prompt,
                    pooling_params,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=priority,
                    data_parallel_rank=data_parallel_rank,
                    tokenization_kwargs=tokenization_kwargs,
            ):
                yield LLMEngine.validate_output(output, PoolingRequestOutput)
        except asyncio.CancelledError:
            # The consumer went away: free the engine-side request, then
            # propagate the cancellation.
            await self.abort(request_id)
            raise
1022

Antoni Baum's avatar
Antoni Baum committed
1023
1024
    async def abort(self, request_id: str) -> None:
        """Abort a submitted request.

        A request that has already finished, or that is not found, is left
        alone (no-op).

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if self.is_running:
            return self._abort(request_id)
        raise AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")
1040

Antoni Baum's avatar
Antoni Baum committed
1041
    def _abort(self, request_id: str) -> None:
        """Synchronously abort a request through the request tracker.

        ``asyncio.CancelledError`` is passed to the tracker as the exception
        to report for the aborted request. Finished or unknown request ids
        are a no-op.

        Args:
            request_id: The unique id of the request.
        """
        tracker = self._request_tracker
        tracker.abort_request(
            request_id,
            exception=asyncio.CancelledError,
            verbose=self.log_requests,
        )
1053

1054
1055
1056
1057
    async def get_vllm_config(self) -> VllmConfig:
        """Return the vLLM configuration reported by the wrapped engine."""
        config = self.engine.get_vllm_config()
        return config

1058
1059
    async def get_model_config(self) -> ModelConfig:
        """Return the model configuration reported by the wrapped engine."""
        config = self.engine.get_model_config()
        return config
1061

1062
1063
    async def get_parallel_config(self) -> ParallelConfig:
        """Return the parallel configuration reported by the wrapped engine."""
        config = self.engine.get_parallel_config()
        return config
1065

1066
1067
    async def get_decoding_config(self) -> DecodingConfig:
        """Return the decoding configuration reported by the wrapped engine."""
        config = self.engine.get_decoding_config()
        return config
1069

1070
1071
    async def get_scheduler_config(self) -> SchedulerConfig:
        """Return the scheduler configuration reported by the wrapped
        engine."""
        config = self.engine.get_scheduler_config()
        return config
1073
1074
1075

    async def get_lora_config(self) -> LoRAConfig:
        """Return the LoRA configuration reported by the wrapped engine."""
        config = self.engine.get_lora_config()
        return config
1077

1078
1079
1080
1081
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the wrapped engine to log its stats.

        ``scheduler_outputs`` and ``model_output`` are accepted but not used
        here. The engine's ``do_log_stats`` is called without arguments.
        """
        engine = self.engine
        engine.do_log_stats()
1083

1084
    async def check_health(self) -> None:
        """Raise an error if the engine is unhealthy.

        Any exception raised by the wrapped engine's ``check_health_async``
        propagates unchanged. The elapsed time is logged at debug level on
        success.

        Raises:
            AsyncEngineDeadError: If the background loop is stopped.
        """
        started = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        await self.engine.check_health_async()
        elapsed = time.perf_counter() - started
        logger.debug("Health check took %fs", elapsed)
1093
1094

    async def is_tracing_enabled(self) -> bool:
        """Report whether tracing is enabled on the wrapped engine."""
        engine = self.engine
        return engine.is_tracing_enabled()
1096
1097

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger on the wrapped engine under ``logger_name``."""
        engine = self.engine
        engine.add_logger(logger_name=logger_name, logger=logger)
1099
1100

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the wrapped engine's stat logger named ``logger_name``."""
        engine = self.engine
        engine.remove_logger(logger_name=logger_name)
1102
1103

    async def start_profile(self) -> None:
        """Start profiling on the wrapped engine."""
        engine = self.engine
        engine.start_profile()
1105
1106

    async def stop_profile(self) -> None:
        """Stop profiling on the wrapped engine."""
        engine = self.engine
        engine.stop_profile()
1108

1109
1110
1111
    async def reset_mm_cache(self) -> None:
        """Reset the wrapped engine's multi-modal cache."""
        engine = self.engine
        engine.reset_mm_cache()

1112
1113
1114
    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        """Reset the wrapped engine's prefix cache.

        Args:
            device: Forwarded unchanged to the engine's
                ``reset_prefix_cache`` (``None`` by default).
        """
        engine = self.engine
        engine.reset_prefix_cache(device)
1115

1116
1117
1118
    async def sleep(self, level: int = 1) -> None:
        """Put the wrapped engine to sleep at the given ``level``."""
        engine = self.engine
        engine.sleep(level)

1119
1120
    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake the wrapped engine, forwarding ``tags`` unchanged."""
        engine = self.engine
        engine.wake_up(tags)
1121

1122
1123
1124
    async def is_sleeping(self) -> bool:
        """Report whether the wrapped engine is currently asleep."""
        engine = self.engine
        return engine.is_sleeping()

1125
1126
1127
    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Load the LoRA adapter described by ``lora_request`` into the
        wrapped engine."""
        engine = self.engine
        engine.add_lora(lora_request)

1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
    async def collective_rpc(self,
                             method: str,
                             timeout: Optional[float] = None,
                             args: tuple = (),
                             kwargs: Optional[dict] = None):
        """Perform a collective RPC call to the given path.

        All arguments are forwarded positionally to the wrapped engine's
        ``collective_rpc_async``, and its awaited result is returned.
        """
        rpc = self.engine.collective_rpc_async
        result = await rpc(method, timeout, args, kwargs)
        return result

1139
1140

# TODO(v1): Remove this class proxy when V1 goes default.
# Only when VLLM_USE_V1 is explicitly set AND enabled, the name
# `AsyncLLMEngine` is rebound to the V1 `AsyncLLM` class (both names end up
# bound at module level).
if envs.is_set("VLLM_USE_V1"):
    if envs.VLLM_USE_V1:
        from vllm.v1.engine.async_llm import AsyncLLM

        AsyncLLMEngine = AsyncLLM  # type: ignore