async_llm_engine.py 26.7 KB
Newer Older
1
import asyncio
2
import os
3
import time
Antoni Baum's avatar
Antoni Baum committed
4
from functools import partial
5
6
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
                    Set, Tuple, Type, Union)
7

8
9
from transformers import PreTrainedTokenizer

10
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
11
12
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
13
from vllm.engine.ray_utils import initialize_ray_cluster, ray
Woosuk Kwon's avatar
Woosuk Kwon committed
14
from vllm.logger import init_logger
15
from vllm.lora.request import LoRARequest
Woosuk Kwon's avatar
Woosuk Kwon committed
16
17
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
18
from vllm.sequence import MultiModalData
yhu422's avatar
yhu422 committed
19
from vllm.usage.usage_lib import UsageContext
20
21

logger = init_logger(__name__)

# Maximum number of seconds a single background-loop engine iteration may run
# before the loop gives up (see run_engine_loop). Read once at import time from
# VLLM_ENGINE_ITERATION_TIMEOUT_S; defaults to 60.
ENGINE_ITERATION_TIMEOUT_S = int(
    os.getenv("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60"))
24

Antoni Baum's avatar
Antoni Baum committed
25

26
27
28
29
class AsyncEngineDeadError(RuntimeError):
    """Raised when the async engine's background loop has stopped, errored,
    or is not running, so requests can no longer be served."""


30
31
32
def _raise_exception_on_finish(
        task: asyncio.Task, error_callback: Callable[[Exception],
                                                     None]) -> None:
    """Done-callback attached to the engine's background-loop task.

    The background loop is meant to run forever, so any completion other than
    cancellation is treated as fatal: the failure is logged,
    ``error_callback`` is invoked with it (with an ``AsyncEngineDeadError``
    when the task returned normally), and an ``AsyncEngineDeadError`` chained
    to it is raised.

    A cancelled task is an intentional shutdown. ``task.result()`` would raise
    ``asyncio.CancelledError`` (a ``BaseException``, so not caught by
    ``except Exception``), which would otherwise escape from this callback;
    instead the callback returns quietly.

    Args:
        task: The finished background-loop task.
        error_callback: Called with the exception that ended the loop.

    Raises:
        AsyncEngineDeadError: Whenever the task finished without being
            cancelled.
    """
    msg = ("Task finished unexpectedly. This should never happen! "
           "Please open an issue on Github.")

    if task.cancelled():
        # Deliberate cancellation (e.g. event-loop shutdown), not a crash.
        return

    try:
        task.result()
        # NOTE: This will be thrown if task exits normally (which it should not)
        raise AsyncEngineDeadError(msg)
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        raise AsyncEngineDeadError(
            msg + " See stack trace above for the actual cause.") from e
47
48


Antoni Baum's avatar
Antoni Baum committed
49
50
51
52
53
54
55
56
57
class AsyncStream:
    """Per-request stream of ``RequestOutput`` objects consumed with
    ``async for``.

    Producers push outputs (or an exception to surface to the consumer) with
    :meth:`put` and close the stream with :meth:`finish`. Iteration raises any
    queued exception and stops once the end-of-stream marker is reached.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, Exception]) -> None:
        """Queue ``item`` for the consumer; silently dropped once finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(self) -> None:
        """Enqueue the end-of-stream marker and reject further items."""
        # StopAsyncIteration is an Exception subclass, so __anext__ raises it
        # like any other queued exception, which ends the `async for`.
        end_marker = StopAsyncIteration()
        self._queue.put_nowait(end_marker)
        self._finished = True

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has been called."""
        return self._finished

    def __aiter__(self):
        return self

    async def __anext__(self) -> RequestOutput:
        item = await self._queue.get()
        if not isinstance(item, Exception):
            return item
        raise item


81
82
83
84
85
86
87
88
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    The frontend queues new requests and aborts here; the background engine
    loop drains both queues once per iteration through
    :meth:`get_new_and_finished_requests` and forwards them to the engine.
    """

    def __init__(self) -> None:
        # Streams of requests that have been handed to the engine.
        self._request_streams: Dict[str, AsyncStream] = {}
        # IDs of requests finished/aborted since the last drain.
        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, kwargs for the engine's add_request) pairs awaiting the
        # next drain.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # IDs currently sitting in _new_requests. They are not yet in
        # _request_streams, so without this set a duplicate ID submitted
        # before the next drain would slip past the check in add_request.
        self._pending_request_ids: Set[str] = set()
        # Set whenever a request is queued; awaited by wait_for_new_requests.
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None).

        When ``request_id`` is None, requests that were queued but not yet
        handed to the engine are failed as well, so their consumers are not
        left waiting on a loop that will never drain them.
        """
        if request_id is not None:
            self._request_streams[request_id].put(exc)
            self.abort_request(request_id)
            return

        for rid, stream in self._request_streams.items():
            stream.put(exc)
            self.abort_request(rid)

        while not self._new_requests.empty():
            stream, _ = self._new_requests.get_nowait()
            self._pending_request_ids.discard(stream.request_id)
            stream.put(exc)
            stream.finish()

    def process_request_output(self,
                               request_output: RequestOutput,
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        The output is pushed to its stream; a finished output also aborts the
        request so its stream is closed and dropped on the next drain.
        """
        request_id = request_output.request_id

        self._request_streams[request_id].put(request_output)
        if request_output.finished:
            if verbose:
                logger.info("Finished request %s.", request_id)
            self.abort_request(request_id)

    def process_exception(self,
                          request_id: str,
                          exception: Exception,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine to one request's stream and
        abort that request."""
        self._request_streams[request_id].put(exception)
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id)

    def add_request(self, request_id: str,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Raises:
            KeyError: If a request with the same ID is already tracked or is
                still waiting to be handed to the engine.
        """
        if (request_id in self._request_streams
                or request_id in self._pending_request_ids):
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))
        self._pending_request_ids.add(request_id)

        self.new_requests_event.set()

        return stream

    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
        """Abort a request during next background loop iteration."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._finished_requests.put_nowait(request_id)

        if request_id not in self._request_streams or self._request_streams[
                request_id].finished:
            # The request has already finished or been aborted.
            return

        self._request_streams[request_id].finish()

    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        Returns:
            The engine ``add_request`` kwargs of every newly queued request
            that was not aborted before this drain, and the set of request IDs
            that finished or were aborted since the previous drain.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._finished_requests.empty():
            request_id = self._finished_requests.get_nowait()
            finished_requests.add(request_id)
            self._request_streams.pop(request_id, None)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            self._pending_request_ids.discard(stream.request_id)
            if stream.request_id in finished_requests:
                # The request has already been aborted.
                stream.finish()
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Wait until at least one new request is queued, then clear the
        event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Whether any request is queued and not yet drained."""
        return not self._new_requests.empty()
194

Antoni Baum's avatar
Antoni Baum committed
195
196
197
198
199
200
201
202
203
204
205
206
207
208

class _AsyncLLMEngine(LLMEngine):
    """``LLMEngine`` subclass exposing coroutine versions of stepping,
    prompt encoding, request admission and health checking."""

    async def step_async(self) -> List[RequestOutput]:
        """Run one decoding iteration and return the newly generated results.

        The scheduler picks the sequence groups for this iteration together
        with the token blocks to swap in/out/copy. If anything was scheduled,
        the model executor is awaited; the (possibly empty) model output is
        then processed into ``RequestOutput`` objects, which updates the
        scheduler and decodes the sequences.
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

        if scheduler_outputs.is_empty():
            # Nothing was scheduled: skip the model call.
            output = []
        else:
            output = await self.model_executor.execute_model_async(
                seq_group_metadata_list,
                scheduler_outputs.blocks_to_swap_in,
                scheduler_outputs.blocks_to_swap_out,
                scheduler_outputs.blocks_to_copy,
            )

        return self._process_model_outputs(
            output,
            scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups,
        )

    async def encode_request_async(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]] = None,
        lora_request: Optional[LoRARequest] = None,
    ):
        """Return the prompt's token IDs.

        ``prompt_token_ids`` is returned unchanged when given; otherwise
        ``prompt`` (which must then be non-None) is tokenized asynchronously.
        """
        if prompt_token_ids is not None:
            return prompt_token_ids
        assert prompt is not None
        return await self.tokenizer.encode_async(request_id=request_id,
                                                 prompt=prompt,
                                                 lora_request=lora_request)

    async def add_request_async(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        """Tokenize the prompt asynchronously if needed, then add the request
        to the engine through the synchronous ``add_request``.

        Raises:
            ValueError: If ``lora_request`` is given but LoRA is not enabled.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        arrival_time = time.time() if arrival_time is None else arrival_time
        token_ids = await self.encode_request_async(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            lora_request=lora_request,
        )

        return self.add_request(
            request_id,
            prompt=prompt,
            prompt_token_ids=token_ids,
            sampling_params=sampling_params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            multi_modal_data=multi_modal_data,
        )

    async def check_health_async(self) -> None:
        """Coroutine wrapper around the model executor's health check."""
        self.model_executor.check_health()
270

271

272
273
class AsyncLLMEngine:
    """An asynchronous wrapper for LLMEngine.

    This class is used to wrap the LLMEngine class to make it asynchronous. It
    uses asyncio to create a background loop that keeps processing incoming
    requests. The LLMEngine is kicked by the generate method when there
    are requests in the waiting queue. The generate method yields the outputs
    from the LLMEngine to the caller.

    NOTE: For the comprehensive list of arguments, see `LLMEngine`.

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
            async frontend will be executed in a separate process from the
            model workers.
        log_requests: Whether to log the requests.
        max_log_len: Maximum number of prompt characters or prompt ID numbers
            being printed in log.
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
        *args: Arguments for LLMEngine.
        **kwargs: Arguments for LLMEngine.
    """

    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 max_log_len: Optional[int] = None,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
        self.max_log_len = max_log_len
        self.engine = self._init_engine(*args, **kwargs)

        self.background_loop = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded = None
        self.start_engine_loop = start_engine_loop
        self._request_tracker: Optional[RequestTracker] = None
        self._errored_with: Optional[BaseException] = None

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_config = engine_args.create_engine_config()

        if engine_config.device_config.device_type == "neuron":
            raise NotImplementedError("Neuron is not supported for "
                                      "async engine yet.")
        elif engine_config.parallel_config.worker_use_ray:
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            executor_class = RayGPUExecutorAsync
        else:
            assert engine_config.parallel_config.world_size == 1, (
                "Ray is required if parallel_config.world_size > 1.")
            from vllm.executor.gpu_executor import GPUExecutorAsync
            executor_class = GPUExecutorAsync
        # Create the async LLM engine.
        engine = cls(
            engine_config.parallel_config.worker_use_ray,
            engine_args.engine_use_ray,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            max_log_len=engine_args.max_log_len,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
        )
        return engine

    @property
    def is_running(self) -> bool:
        """Whether the background loop task exists and has not completed."""
        return (self.background_loop is not None
                and not self._background_loop_unshielded.done())

    @property
    def is_stopped(self) -> bool:
        """Whether the engine errored or its background loop task completed."""
        return self.errored or (self.background_loop is not None
                                and self._background_loop_unshielded.done())

    @property
    def errored(self) -> bool:
        return self._errored_with is not None

    def set_errored(self, exc: Exception) -> None:
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        # Invoked by the background task's done-callback: mark the engine as
        # errored and fail every tracked request stream with `exc`.
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)

    async def get_tokenizer(self) -> "PreTrainedTokenizer":
        if self.engine_use_ray:
            return await self.engine.get_tokenizer.remote()
        else:
            return self.engine.get_tokenizer()

    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: If the loop has errored before.
            RuntimeError: If the loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        self._background_loop_unshielded = asyncio.get_event_loop(
        ).create_task(self.run_engine_loop())
        self._background_loop_unshielded.add_done_callback(
            partial(_raise_exception_on_finish,
                    error_callback=self._error_callback))
        self.background_loop = asyncio.shield(self._background_loop_unshielded)

    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
        if not self.engine_use_ray:
            engine_class = self._engine_class
        elif self.worker_use_ray:
            engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
        else:
            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
            # order of the arguments.
            cache_config = kwargs["cache_config"]
            parallel_config = kwargs["parallel_config"]
            if parallel_config.tensor_parallel_size == 1:
                num_gpus = cache_config.gpu_memory_utilization
            else:
                num_gpus = 1
            engine_class = ray.remote(num_gpus=num_gpus)(
                self._engine_class).remote
        return engine_class(*args, **kwargs)

    async def engine_step(self) -> bool:
        """Kick the engine to process the waiting requests.

        Returns True if there are in-progress requests."""

        new_requests, finished_requests = (
            self._request_tracker.get_new_and_finished_requests())

        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            # TODO: Maybe add add_request_batch to reduce Ray overhead
            try:
                if self.engine_use_ray:
                    await self.engine.add_request.remote(**new_request)
                else:
                    await self.engine.add_request_async(**new_request)
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    new_request["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if finished_requests:
            await self._engine_abort(finished_requests)

        if self.engine_use_ray:
            request_outputs = await self.engine.step.remote()
        else:
            request_outputs = await self.engine.step_async()

        # Put the outputs into the corresponding streams.
        for request_output in request_outputs:
            self._request_tracker.process_request_output(
                request_output, verbose=self.log_requests)

        return len(request_outputs) > 0

    async def _engine_abort(self, request_ids: Iterable[str]):
        if self.engine_use_ray:
            await self.engine.abort_request.remote(request_ids)
        else:
            self.engine.abort_request(request_ids)

    async def run_engine_loop(self):
        has_requests_in_progress = False
        while True:
            if not has_requests_in_progress:
                logger.debug("Waiting for new requests...")
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                has_requests_in_progress = await asyncio.wait_for(
                    self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise
            await asyncio.sleep(0)

    async def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> AsyncStream:
        """Tokenize the prompt if needed and register the request with the
        request tracker, starting the background loop if allowed.

        Returns:
            The `AsyncStream` that will receive the request's outputs.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                `start_engine_loop` is False, or it errored before.
            KeyError: If the request ID is already being tracked.
        """
        if self.log_requests:
            shortened_prompt = prompt
            shortened_token_ids = prompt_token_ids
            if self.max_log_len is not None:
                if shortened_prompt is not None:
                    shortened_prompt = shortened_prompt[:self.max_log_len]
                if shortened_token_ids is not None:
                    shortened_token_ids = shortened_token_ids[:self.
                                                              max_log_len]
            logger.info(
                "Received request %s: prompt: %r, sampling_params: %s, "
                "prompt_token_ids: %s, lora_request: %s.", request_id,
                shortened_prompt, sampling_params, shortened_token_ids,
                lora_request)

        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        if arrival_time is None:
            arrival_time = time.time()

        if self.engine_use_ray:
            prompt_token_ids = await self.engine.encode_request_async.remote(
                request_id=request_id,
                prompt=prompt,
                prompt_token_ids=prompt_token_ids,
                lora_request=lora_request)
        else:
            prompt_token_ids = await self.engine.encode_request_async(
                request_id=request_id,
                prompt=prompt,
                prompt_token_ids=prompt_token_ids,
                lora_request=lora_request)

        stream = self._request_tracker.add_request(
            request_id,
            prompt=prompt,
            sampling_params=sampling_params,
            prompt_token_ids=prompt_token_ids,
            arrival_time=arrival_time,
            lora_request=lora_request,
            multi_modal_data=multi_modal_data,
        )

        return stream

    async def generate(
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        request_id: str,
        prompt_token_ids: Optional[List[int]] = None,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None
    ) -> AsyncIterator[RequestOutput]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.

        Args:
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.
            lora_request: LoRA request to use for generation, if any.
            multi_modal_data: Multi modal data per request.

        Yields:
            The output `RequestOutput` objects from the LLMEngine for the
            request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.
            - If consuming the stream fails, is cancelled, or the generator
              is closed early, the request is aborted.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        # Preprocess the request.
        arrival_time = time.time()

        # Register the request outside the abort-on-error block: if this
        # fails (duplicate request_id, tokenization error, dead engine) there
        # is nothing of ours to abort, and calling _abort here would cancel a
        # different in-flight request sharing this id, or fail on a tracker
        # that was never created and mask the original error.
        stream = await self.add_request(
            request_id,
            prompt,
            sampling_params,
            prompt_token_ids=prompt_token_ids,
            arrival_time=arrival_time,
            lora_request=lora_request,
            multi_modal_data=multi_modal_data,
        )

        try:
            async for request_output in stream:
                yield request_output
        except (Exception, asyncio.CancelledError, GeneratorExit):
            # If there is an exception, the coroutine is cancelled, or the
            # consumer closes this generator early (GeneratorExit), abort the
            # request so the engine stops working on it.
            self._abort(request_id)
            raise

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if not self.is_running:
            raise AsyncEngineDeadError(
                "Background loop is not running. If it was running, "
                "inspect the output to find the stacktrace of the "
                "error that caused the background loop to stop "
                "(AsyncEngineDeadError).")

        return self._abort(request_id)

    def _abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
        self._request_tracker.abort_request(request_id,
                                            verbose=self.log_requests)

    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        if self.engine_use_ray:
            return await self.engine.get_model_config.remote()
        else:
            return self.engine.get_model_config()

    async def do_log_stats(self) -> None:
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote()
        else:
            self.engine.do_log_stats()

    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        t = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        if self.engine_use_ray:
            try:
                await self.engine.check_health.remote()
            except ray.exceptions.RayActorError as e:
                raise RuntimeError("Engine is dead.") from e
        else:
            await self.engine.check_health_async()
        logger.debug("Health check took %ss", time.perf_counter() - t)