"tests/kernels/untest_flashinfer.py" did not exist on "6ffa3f314c59e42238f1c5f923ff2839e0af9698"
async_llm_engine.py 40.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import asyncio
import time
6
import weakref
Antoni Baum's avatar
Antoni Baum committed
7
from functools import partial
8
9
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                    Mapping, Optional, Set, Tuple, Type, Union)
10
from weakref import ReferenceType
11

12
import vllm.envs as envs
13
from vllm.config import (LoRAConfig, ModelConfig, ParallelConfig,
14
                         SchedulerConfig, VllmConfig)
15
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
16
from vllm.engine.arg_utils import AsyncEngineArgs
17
from vllm.engine.async_timeout import asyncio_timeout
18
from vllm.engine.llm_engine import LLMEngine
19
from vllm.engine.metrics_types import StatLoggerBase
20
from vllm.engine.protocol import EngineClient
21
from vllm.executor.executor_base import ExecutorBase
22
from vllm.inputs import PromptType
23
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
24
from vllm.logger import init_logger
25
from vllm.lora.request import LoRARequest
26
from vllm.model_executor.layers.sampler import SamplerOutput
27
from vllm.outputs import PoolingRequestOutput, RequestOutput
28
from vllm.pooling_params import PoolingParams
29
from vllm.sampling_params import SamplingParams
30
from vllm.sequence import ExecuteModelRequest
31
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
32
from vllm.usage.usage_lib import UsageContext
33
from vllm.utils import Device, deprecate_kwargs, weak_bind
34
35

# Module-level logger created through vLLM's logging helper.
logger = init_logger(__name__)
# Upper bound on a single engine-loop iteration, read once at import time
# from the environment configuration (presumably seconds, per the `_S`
# suffix).
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
37

Antoni Baum's avatar
Antoni Baum committed
38

39
40
41
42
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine's background loop has stopped or failed."""


43
44
45
46
47
48
49
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback that reports how the engine background task ended.

    This function is only intended for the `engine.run_engine_loop()` task.
    In particular, that task runs a `while True` loop that is expected to
    exit only through an exception.

    Args:
        task: The finished background task whose outcome is inspected.
        error_callback: Invoked with the exception when the task failed;
            this includes the `AssertionError` raised below when the task
            returned a value instead of raising.

    Raises:
        AsyncEngineDeadError: If the task ended for any reason other than
            cancellation. The original exception is chained as the cause.
    """
    try:
        return_value = task.result()
        # Reaching this line means the task returned normally, which is
        # treated as a failure in its own right.
        raise AssertionError(
            "The engine background task should never finish without an "
            f"exception. {return_value}")
    except asyncio.exceptions.CancelledError:
        # We assume that if the task is cancelled, we are gracefully shutting
        # down. This should only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on GitHub. See stack trace above for the "
            "actual cause.") from e
69
70


71
72
73
# Marker object queued by AsyncStream.finish() when a stream ends without an
# error; AsyncStream.generator() stops iterating when it dequeues it.
STOP_ITERATION = Exception()  # Sentinel


Antoni Baum's avatar
Antoni Baum committed
74
class AsyncStream:
    """A stream of RequestOutputs for a request that can be iterated over
    asynchronously via an async generator."""

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        """Create an empty, unfinished stream.

        Args:
            request_id: Id of the request this stream belongs to.
            cancel: Callback invoked with `request_id` when the consumer
                closes the generator before the stream has ended.
        """
        self.request_id = request_id
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, Exception]) -> None:
        """Enqueue `item` for the consumer; ignored once finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Mark the stream as finished and enqueue its terminal item.

        The terminal item is `exception` when it is raisable (an exception
        instance or class), otherwise the `STOP_ITERATION` sentinel. Calling
        this more than once has no further effect.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        """Whether `finish()` has been called."""
        return self._finished

    async def generator(self) -> AsyncGenerator[RequestOutput, None]:
        """Yield queued items until the stream ends.

        Returns when the `STOP_ITERATION` sentinel is dequeued and raises any
        other dequeued exception. If the generator is closed early, the
        cancel callback is invoked and `asyncio.CancelledError` is raised.
        """
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    # Identity check: the sentinel is one specific object,
                    # so it must not depend on any exception's __eq__.
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """Return True if `value` is an exception instance or class."""
        return isinstance(value, BaseException) or \
                (isinstance(value, type) and \
                 issubclass(value, BaseException))

Antoni Baum's avatar
Antoni Baum committed
120

121
122
123
124
125
class RequestTracker:
    """Synchronous abstraction for tracking requests."""

    def __init__(self) -> None:
        # Streams of the requests currently tracked, keyed by request id.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Ids of aborted requests, drained by get_new_and_aborted_requests().
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, engine add_request kwargs) pairs that have not yet been
        # handed over via get_new_and_aborted_requests().
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set by add_request(); awaited by wait_for_new_requests().
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item) -> bool:
        """Return True if `item` is the id of a tracked request stream."""
        return item in self._request_streams

    def __len__(self) -> int:
        """Return the number of tracked request streams."""
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None)."""
        if request_id is not None:
            self.abort_request(request_id, exception=exc)
        else:
            # NB: tuple() used here because self.abort_request pops the stream
            # out of self._request_streams, so we can't iterate on it directly
            for rid in tuple(self._request_streams):
                self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: RequestOutput,
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        The output is put on the matching stream; when the output is marked
        finished the stream is also removed from tracking and finished.
        """
        request_id = request_output.request_id
        finished = request_output.finished

        if finished:
            stream = self._request_streams.pop(request_id, None)
        else:
            stream = self._request_streams.get(request_id)
        # Guard against a KeyError which can occur if the request was aborted
        # while the output was generated
        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if verbose and finished:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Returns:
            The stream on which the request's outputs will be delivered.

        Raises:
            KeyError: If a stream for `request_id` is already tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        # The stream's cancel callback aborts the request with the same
        # verbosity that was used to add it.
        abort_request = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, abort_request)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Abort a request during next background loop iteration."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine."""
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._aborted_requests.empty():
            request_id = self._aborted_requests.get_nowait()
            finished_requests.add(request_id)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            request_id = stream.request_id
            if request_id in finished_requests:
                # The request has already been aborted.
                stream.finish(asyncio.CancelledError)
                finished_requests.discard(request_id)
            else:
                self._request_streams[request_id] = stream
                new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self) -> None:
        """Wait until a new request is queued, then clear the event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self) -> bool:
        """Return True if at least one new request is still queued."""
        return not self._new_requests.empty()
252

Antoni Baum's avatar
Antoni Baum committed
253
254
255
256

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def step_async(self, virtual_engine: int) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are ran asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.

        Args:
            virtual_engine: Index selecting the per-virtual-engine scheduler,
                scheduler context, cached outputs and async callback.

        Returns:
            The request outputs accumulated in this virtual engine's
            scheduler context.
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):

            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            # NOTE: finished_requests_ids is only bound here when something
            # was scheduled; it is only read further down under the same
            # `not scheduler_outputs.is_empty()` condition.
            if not scheduler_outputs.is_empty():
                # this will cause mamba_cache/minimax_cache failed
                # to release finished_requests_ids of the last steps
                finished_requests_ids = self.scheduler[
                    virtual_engine].get_and_reset_finished_requests_ids()

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

        else:
            finished_requests_ids = []

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

        else:
            # Nothing to run: flush any queued outputs first.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []

        if not self._has_remaining_steps(seq_group_metadata_list):
            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            if outputs and allow_async_output_proc:
                assert len(
                    outputs
                ) == 1, "Async postprocessor expects only a single output set"
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)

        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def get_tokenizer_async(self) -> AnyTokenizer:
        """Async wrapper that returns `self.get_tokenizer()`."""
        return self.get_tokenizer()

    async def add_request_async(
        self,
        request_id: str,
        prompt: PromptType,
        params: SamplingParams,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        """
        Async version of
        [`add_request`][vllm.engine.llm_engine.LLMEngine.add_request].

        Raises:
            ValueError: If a LoRA request is given while LoRA is not enabled,
                if a non-zero priority is given while the scheduler policy is
                not "priority", or if `data_parallel_rank` is set.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")
        if arrival_time is None:
            arrival_time = time.time()

        if data_parallel_rank is not None:
            raise ValueError("Targeting data_parallel_rank only supported "
                             "in v1 client.")

        # When only embeddings are supplied, fill in placeholder token ids
        # (zeros) of matching length. NOTE: this mutates the caller's dict.
        if (isinstance(prompt, dict)
                and prompt.get("prompt_embeds", None) is not None
                and not prompt.get("prompt_token_ids", None)):
            # We use the -2 dimension (instead of 0) in case a batched input
            # of batch size 1 is passed in.
            prompt["prompt_token_ids"] = [0
                                          ] * prompt["prompt_embeds"].shape[-2]

        processed_inputs = await self.input_preprocessor.preprocess_async(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
        )

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    async def check_health_async(self) -> None:
        """Async wrapper around the model executor's `check_health()`."""
        self.model_executor.check_health()

    async def collective_rpc_async(self,
                                   method: str,
                                   timeout: Optional[float] = None,
                                   args: tuple = (),
                                   kwargs: Optional[dict] = None):
        """Not implemented for this engine; always raises."""
        raise NotImplementedError

457

458
class AsyncLLMEngine(EngineClient):
459
    """An asynchronous wrapper for [`LLMEngine`][vllm.LLMEngine].
460

461
462
463
464
465
466
    This class is used to wrap the [`LLMEngine`][vllm.LLMEngine] class to
    make it asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The [`LLMEngine`][vllm.LLMEngine] is kicked
    by the generate method when there are requests in the waiting queue. The
    generate method yields the outputs from the [`LLMEngine`][vllm.LLMEngine]
    to the caller.
467
468

    Args:
469
        log_requests: Whether to log the requests.
470
471
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
472
473
        *args: Arguments for [`LLMEngine`][vllm.LLMEngine].
        **kwargs: Arguments for [`LLMEngine`][vllm.LLMEngine].
474
    """
475

Antoni Baum's avatar
Antoni Baum committed
476
477
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

478
    def __init__(self,
                 *args: Any,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs: Any) -> None:
        """Build the wrapped engine and initialise background-loop state.

        Args:
            *args: Positional arguments forwarded to `_engine_class`.
            log_requests: Whether to log the requests.
            start_engine_loop: Stored on the instance as
                `start_engine_loop`; not acted on here.
            **kwargs: Keyword arguments forwarded to `_engine_class`.

        Raises:
            ValueError: If `envs.VLLM_USE_V1` is set.
        """
        if envs.VLLM_USE_V1:
            raise ValueError(
                "Using V0 AsyncLLMEngine, but envs.VLLM_USE_V1=True. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        self.log_requests = log_requests
        self.engine = self._engine_class(*args, **kwargs)

        # This ensures quick processing of request outputs
        # so the append to asyncio queues is not delayed,
        # especially for multi-step.
        self.use_process_request_outputs_callback = (
            self.engine.model_config.use_async_output_proc)

        if self.use_process_request_outputs_callback:
            self.engine.process_request_outputs_callback = \
                weak_bind(self.process_request_outputs)

        self.background_loop: Optional[asyncio.Future] = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Lazy initialized fields
        # (declared only; assigned in start_background_loop())
        self._request_tracker: RequestTracker

514
515
516
517
518
    def __del__(self):
        """Wake the background loop on finalization so it can exit."""
        # The tracker is stored as `_request_tracker` and is only assigned in
        # start_background_loop(), hence the getattr default.
        rt = getattr(self, "_request_tracker", None)
        # Compare against None explicitly: RequestTracker defines __len__, so
        # a tracker with no active streams is falsy and a truthiness test
        # would skip the wake-up exactly when the loop is idle and waiting.
        if rt is not None:
            # Wake up engine loop so that it will exit cleanly
            rt.new_requests_event.set()

519
    @classmethod
520
521
522
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Return the executor class that `LLMEngine._get_executor_cls`
        selects for `engine_config`."""
        executor_cls = LLMEngine._get_executor_cls(engine_config)
        return executor_cls
523

524
    @classmethod
525
526
527
528
529
    @deprecate_kwargs(
        "disable_log_requests",
        additional_message=("This argument will have no effect. "
                            "Use `enable_log_requests` instead."),
    )
530
    def from_vllm_config(
            cls,
            vllm_config: VllmConfig,
            start_engine_loop: bool = True,
            usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
            stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
            enable_log_requests: bool = False,
            disable_log_stats: bool = False,
            disable_log_requests: bool = True,  # Deprecated, will be removed
    ) -> "AsyncLLMEngine":
        """Create an AsyncLLMEngine from a VllmConfig.

        Args:
            vllm_config: Configuration forwarded to the constructor; also
                used to pick the executor class.
            start_engine_loop: Forwarded to the constructor.
            usage_context: Forwarded to the constructor.
            stat_loggers: Forwarded to the constructor.
            enable_log_requests: Passed to the constructor as `log_requests`.
            disable_log_stats: Inverted and passed as `log_stats`.
            disable_log_requests: Deprecated; not read by this method.
        """

        return cls(
            vllm_config=vllm_config,
            executor_class=cls._get_executor_cls(vllm_config),
            start_engine_loop=start_engine_loop,
            log_requests=enable_log_requests,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

552
553
554
555
556
557
558
559
560
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        The engine config is built from `engine_args`, then handed to
        `from_vllm_config` of either this class or, when `envs.VLLM_USE_V1`
        is set, the V1 `AsyncLLM` class.
        """

        vllm_config = engine_args.create_engine_config(usage_context)

        async_engine_cls = cls
        if envs.VLLM_USE_V1:
            # Imported lazily, only when the V1 engine is selected.
            from vllm.v1.engine.async_llm import AsyncLLM as V1AsyncLLMEngine
            async_engine_cls = V1AsyncLLMEngine

        return async_engine_cls.from_vllm_config(
            vllm_config=vllm_config,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            disable_log_stats=engine_args.disable_log_stats,
            enable_log_requests=engine_args.enable_log_requests,
        )
577

578
579
    @property
    def is_running(self) -> bool:
        """True when a background loop task exists and is not done."""
        return (self.background_loop is not None
                and self._background_loop_unshielded is not None
                and not self._background_loop_unshielded.done())

    @property
    def is_stopped(self) -> bool:
        """True when the engine has errored, or when a background loop task
        exists and is done."""
        return self.errored or (self.background_loop is not None and
                                self._background_loop_unshielded is not None
                                and self._background_loop_unshielded.done())

    @property
    def errored(self) -> bool:
        """Whether an exception has been recorded on this engine."""
        recorded = self._errored_with
        return recorded is not None

594
    @property
595
596
597
598
599
600
    def dead_error(self) -> BaseException:
        """Build a new `AsyncEngineDeadError` describing a stopped loop."""
        message = ("Background loop is not running. If it was running, "
                   "inspect the output to find the stacktrace of the "
                   "error that caused the background loop to stop "
                   "(AsyncEngineDeadError).")
        return AsyncEngineDeadError(message)
601

602
603
604
605
606
607
    def set_errored(self, exc: Exception) -> None:
        """Record `exc` as the error this engine failed with."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Record `exc` on the engine and propagate it to the request
        tracker with no request id, i.e. to every tracked stream."""
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)
608

609
610
611
    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Return the wrapped engine's input preprocessor."""
        preprocessor = self.engine.input_preprocessor
        return preprocessor

612
613
    async def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer obtained from the wrapped engine."""
        tokenizer = self.engine.get_tokenizer()
        return tokenizer
614

615
    def start_background_loop(self) -> None:
        """Start the background loop.

        Creates a fresh `RequestTracker`, schedules `run_engine_loop` as a
        task on the current event loop and stores both the task and a
        shielded wrapper around it.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the background loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        # The loop receives only a weak reference to this engine.
        self._background_loop_unshielded = asyncio.get_event_loop(
        ).create_task(self.run_engine_loop(weakref.ref(self)))
        self._background_loop_unshielded.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        self.background_loop = asyncio.shield(self._background_loop_unshielded)
Antoni Baum's avatar
Antoni Baum committed
630

631
632
633
634
635
636
637
638
639
640
641
642
643
644
    def shutdown_background_loop(self) -> None:
        """
        Cancel the background loop task and drop the references to it.

        Call this during cleanup: it removes references to `self` so the
        resources held by the async LLM engine (e.g., the executors and
        their resources) can be garbage collected.
        """
        pending_task = self._background_loop_unshielded
        if pending_task is not None:
            pending_task.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

645
    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        New requests from the tracker are added to the engine, aborted ones
        are aborted, and then one engine step is run for `virtual_engine`.

        Returns True if there are in-progress requests."""

        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())

        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            try:
                await self.engine.add_request_async(**new_request)
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    new_request["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if aborted_requests:
            await self._engine_abort(aborted_requests)

        request_outputs = await self.engine.step_async(virtual_engine)

        # Put the outputs into the corresponding streams.
        # If used as a callback, then already invoked inside
        # LLMEngine's _process_model_outputs
        if not self.use_process_request_outputs_callback:
            all_finished = self.process_request_outputs(request_outputs)
        else:
            # For callback case, we only need to detect when all
            # requests are finished
            all_finished = all(request_output.finished
                               for request_output in request_outputs)

        return not all_finished

    def process_request_outputs(self, request_outputs) -> bool:
        """Forward each output to the request tracker.

        Returns:
            True if every output in `request_outputs` is finished (also True
            for an empty collection).
        """
        # Put the outputs into the corresponding streams.
        all_finished = True
        for request_output in request_outputs:
            self._request_tracker.process_request_output(
                request_output, verbose=self.log_requests)
            all_finished = all_finished and request_output.finished

        return all_finished
692

Antoni Baum's avatar
Antoni Baum committed
693
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Abort the given request ids on the wrapped engine."""
        self.engine.abort_request(request_ids)
Antoni Baum's avatar
Antoni Baum committed
695

696
697
698
699
    @staticmethod
    async def run_engine_loop(engine_ref: ReferenceType):
        """We use a weakref to the engine so that the running loop
        doesn't prevent the engine being garbage collected."""
700
        engine: Optional[AsyncLLMEngine] = engine_ref()
701
702
703
        if not engine:
            return

704
        pipeline_parallel_size = \
705
                engine.engine.parallel_config.pipeline_parallel_size
706
        has_requests_in_progress = [False] * pipeline_parallel_size
Antoni Baum's avatar
Antoni Baum committed
707
        while True:
708
            if not any(has_requests_in_progress):
709
                logger.debug("Waiting for new requests...")
710
711
712
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
713
                # time out, and unblocks the RPC thread in the workers so that
714
715
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
716
717
718
719
720
721
722
723
724
725
726
727
                await engine.engine.stop_remote_worker_execution_loop_async()
                request_tracker = engine._request_tracker
                # Allow engine to be garbage collected while
                # waiting for new requests
                del engine
                await asyncio.sleep(0)
                if engine_ref() is None:
                    return
                await request_tracker.wait_for_new_requests()
                engine = engine_ref()
                if not engine:
                    return
728
                logger.debug("Got new requests!")
729
                requests_in_progress = [
730
                    asyncio.create_task(engine.engine_step(ve))
731
732
733
                    for ve in range(pipeline_parallel_size)
                ]
                has_requests_in_progress = [True] * pipeline_parallel_size
734
735
736
737

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
738
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
739
740
741
742
743
744
745
746
                    done, _ = await asyncio.wait(
                        requests_in_progress,
                        return_when=asyncio.FIRST_COMPLETED)
                    for _ in range(pipeline_parallel_size):
                        await asyncio.sleep(0)
                for task in done:
                    result = task.result()
                    virtual_engine = requests_in_progress.index(task)
747
                    has_unfinished_requests = (
748
749
                        engine.engine.
                        has_unfinished_requests_for_virtual_engine(
750
                            virtual_engine))
751
752
753
                    if result or has_unfinished_requests:
                        requests_in_progress[virtual_engine] = (
                            asyncio.create_task(
754
                                engine.engine_step(virtual_engine)))
755
756
757
                        has_requests_in_progress[virtual_engine] = True
                    else:
                        has_requests_in_progress[virtual_engine] = False
758
759
760
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
761
                engine.set_errored(exc)
762
                raise
Antoni Baum's avatar
Antoni Baum committed
763
764
            await asyncio.sleep(0)

765
    async def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: SamplingParams,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Register a request with the request tracker.

        Starts the background loop first when it is not running and
        ``self.start_engine_loop`` is set.

        Args:
            request_id: The unique id of the request.
            prompt: The prompt to the LLM.
            params: The sampling parameters of the request.
            arrival_time: Arrival timestamp; ``time.time()`` is used when
                this is ``None``.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request. Must be 0 unless the
                scheduler policy is ``"priority"``.
            data_parallel_rank: Forwarded to the request tracker.
            tokenization_kwargs: Forwarded to the request tracker.

        Returns:
            The generator of the stream created by the request tracker.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                ``self.start_engine_loop`` is falsy.
            ValueError: If ``priority`` is non-zero while the scheduler
                policy is not ``"priority"``.
        """
        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        if (priority != 0
                and self.engine.scheduler_config.policy != "priority"):
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        # Test against None explicitly so that a caller-supplied timestamp
        # of 0.0 is kept instead of being replaced by the current time.
        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            prompt=prompt,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=priority,
            data_parallel_rank=data_parallel_rank,
            tokenization_kwargs=tokenization_kwargs,
        )

        return stream.generator()

    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        This method is an async generator. It adds the request into the
        waiting queue of the LLMEngine and streams the outputs from the
        LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See
                [`PromptType`][vllm.inputs.PromptType] for more details about
                the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.
            data_parallel_rank: The (global) data parallel rank that must
                handle this request. Only applicable if DP is enabled.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Raises:
            asyncio.CancelledError: Re-raised after the request has been
                aborted when the consumer is cancelled.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              [`engine_step`][vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step]
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> # note that engine_args here is AsyncEngineArgs instance
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        try:
            async for output in await self.add_request(
                    request_id,
                    prompt,
                    sampling_params,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=priority,
                    data_parallel_rank=data_parallel_rank,
            ):
                yield LLMEngine.validate_output(output, RequestOutput)
        except asyncio.CancelledError:
            # The consumer was cancelled: drop the request from the engine
            # before propagating the cancellation.
            await self.abort(request_id)
            raise

    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Unsupported entry point for pooling requests.

        All arguments are ignored by this implementation.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Pooling models are not supported in vLLM V0")

    async def abort(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request. Only a single ``str``
                id is accepted, even though the annotation also admits an
                iterable of ids.

        Raises:
            RuntimeError: If ``request_id`` is not a ``str``.
            AsyncEngineDeadError: If the background loop is not running.
        """
        if not isinstance(request_id, str):
            raise RuntimeError("Only single-request abort supported in"
                               " deprecated V0")
        if not self.is_running:
            raise AsyncEngineDeadError(
                "Background loop is not running. If it was running, "
                "inspect the output to find the stacktrace of the "
                "error that caused the background loop to stop "
                "(AsyncEngineDeadError).")

        return self._abort(request_id)

    def _abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        The request tracker is asked to abort the request with
        ``asyncio.CancelledError`` as the exception type.

        Args:
            request_id: The unique id of the request.
        """
        self._request_tracker.abort_request(request_id,
                                            exception=asyncio.CancelledError,
                                            verbose=self.log_requests)

    async def get_vllm_config(self) -> VllmConfig:
        """Get the vllm configuration of the vLLM engine."""
        return self.engine.get_vllm_config()

    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        return self.engine.get_model_config()

    async def get_parallel_config(self) -> ParallelConfig:
        """Get the parallel configuration of the vLLM engine."""
        return self.engine.get_parallel_config()

    async def get_scheduler_config(self) -> SchedulerConfig:
        """Get the scheduling configuration of the vLLM engine."""
        return self.engine.get_scheduler_config()

    async def get_lora_config(self) -> LoRAConfig:
        """Get the lora configuration of the vLLM engine."""
        return self.engine.get_lora_config()

    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the wrapped engine to log its stats.

        Args:
            scheduler_outputs: Accepted but not used by this method.
            model_output: Accepted but not used by this method.
        """
        # NOTE(review): neither argument is forwarded to the engine call;
        # confirm whether that is intentional.
        self.engine.do_log_stats()

    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy.

        Raises:
            AsyncEngineDeadError: If ``self.is_stopped`` is set.
            Exception: Whatever ``self.engine.check_health_async()`` raises.
        """
        started_at = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        await self.engine.check_health_async()
        logger.debug("Health check took %fs",
                     time.perf_counter() - started_at)

    async def is_tracing_enabled(self) -> bool:
        """Return the wrapped engine's ``is_tracing_enabled()`` result."""
        return self.engine.is_tracing_enabled()

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger on the wrapped engine.

        Args:
            logger_name: Name under which the logger is registered.
            logger: The stat logger instance (shadows the module-level
                ``logger`` inside this method).
        """
        self.engine.add_logger(logger_name=logger_name, logger=logger)

    def remove_logger(self, logger_name: str) -> None:
        """Remove the stat logger called ``logger_name`` from the engine."""
        self.engine.remove_logger(logger_name=logger_name)

    async def start_profile(self) -> None:
        """Call ``start_profile()`` on the wrapped engine."""
        self.engine.start_profile()

    async def stop_profile(self) -> None:
        """Call ``stop_profile()`` on the wrapped engine."""
        self.engine.stop_profile()

    async def reset_mm_cache(self) -> None:
        """Call ``reset_mm_cache()`` on the wrapped engine."""
        self.engine.reset_mm_cache()

    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        """Reset the wrapped engine's prefix cache.

        Args:
            device: Passed unchanged to ``self.engine.reset_prefix_cache``;
                defaults to ``None``.
        """
        self.engine.reset_prefix_cache(device)

    async def sleep(self, level: int = 1) -> None:
        """Reset the prefix cache, then put the wrapped engine to sleep.

        Args:
            level: Passed unchanged to ``self.engine.sleep``.
        """
        await self.reset_prefix_cache()
        self.engine.sleep(level)

    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake the wrapped engine up.

        Args:
            tags: Passed unchanged to ``self.engine.wake_up``.
        """
        self.engine.wake_up(tags)

    async def is_sleeping(self) -> bool:
        """Return the wrapped engine's ``is_sleeping()`` result."""
        return self.engine.is_sleeping()

    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward ``lora_request`` to the engine's ``add_lora``.

        Returns:
            Whatever ``self.engine.add_lora`` returns.
        """
        return self.engine.add_lora(lora_request)

    async def collective_rpc(self,
                             method: str,
                             timeout: Optional[float] = None,
                             args: tuple = (),
                             kwargs: Optional[dict] = None):
        """Perform a collective RPC call of the given method.

        Args:
            method: Name of the method to invoke.
            timeout: Passed unchanged to the engine; defaults to ``None``.
            args: Positional arguments for the call.
            kwargs: Keyword arguments for the call.

        Returns:
            The awaited result of ``self.engine.collective_rpc_async``.
        """
        return await self.engine.collective_rpc_async(method, timeout, args,
                                                      kwargs)

# TODO(v1): Remove this class proxy when V1 goes default.
# When VLLM_USE_V1 is explicitly set and truthy, rebind the public name
# AsyncLLMEngine to the V1 AsyncLLM class.
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
    from vllm.v1.engine.async_llm import AsyncLLM

    AsyncLLMEngine = AsyncLLM  # type: ignore