async_llm_engine.py 26.6 KB
Newer Older
1
import asyncio
2
import os
3
import time
Antoni Baum's avatar
Antoni Baum committed
4
from functools import partial
5
6
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
                    Set, Tuple, Type, Union)
7

8
9
from transformers import PreTrainedTokenizer

10
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
11
12
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
13
from vllm.engine.ray_utils import initialize_ray_cluster, ray
Woosuk Kwon's avatar
Woosuk Kwon committed
14
from vllm.logger import init_logger
15
from vllm.lora.request import LoRARequest
Woosuk Kwon's avatar
Woosuk Kwon committed
16
17
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
18
from vllm.sequence import MultiModalData
19
20

logger = init_logger(__name__)

# Upper bound, in seconds, that run_engine_loop grants a single engine_step
# (enforced with asyncio.wait_for). Overridable through the
# VLLM_ENGINE_ITERATION_TIMEOUT_S environment variable; a non-integer value
# makes int() raise ValueError at import time.
ENGINE_ITERATION_TIMEOUT_S = int(
    os.getenv("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60"))
23

Antoni Baum's avatar
Antoni Baum committed
24

25
26
27
28
class AsyncEngineDeadError(RuntimeError):
    """Signals that the engine's background loop cannot serve requests.

    Raised when the background loop task has finished (normally or with an
    error), has already errored, or is not running when work is submitted.
    """


29
30
31
def _raise_exception_on_finish(
        task: asyncio.Task, error_callback: Callable[[Exception],
                                                     None]) -> None:
    """Done-callback attached to the engine's background loop task.

    The background loop is meant to run forever, so any completion other
    than an explicit cancellation is treated as a fatal engine failure. In
    that case the cause is logged, ``error_callback`` is invoked with it, and
    an ``AsyncEngineDeadError`` chained to that cause is raised.

    Args:
        task: The finished background-loop task.
        error_callback: Called with the exception that ended the loop. If the
            loop returned normally, it is called with the
            ``AsyncEngineDeadError`` created for that case.

    Raises:
        AsyncEngineDeadError: Whenever the task did not end by cancellation.
    """
    msg = ("Task finished unexpectedly. This should never happen! "
           "Please open an issue on Github.")

    try:
        task.result()
        # NOTE: This will be thrown if task exits normally (which it should not)
        raise AsyncEngineDeadError(msg)
    except asyncio.CancelledError:
        # Cancellation is a deliberate shutdown, not an engine failure.
        # CancelledError derives from BaseException (Python 3.8+), so without
        # this clause it would slip past ``except Exception`` and escape the
        # done-callback, where the event loop reports it as a callback error.
        return
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        raise AsyncEngineDeadError(
            msg + " See stack trace above for the actual cause.") from e
46
47


Antoni Baum's avatar
Antoni Baum committed
48
49
50
51
52
53
54
55
56
class AsyncStream:
    """Per-request asynchronous channel of engine results.

    Producers push ``RequestOutput`` objects (or exceptions) with :meth:`put`
    and close the channel with :meth:`finish`. A consumer drains it with
    ``async for``. An exception taken off the queue is raised in the
    consumer. The end-of-stream marker is itself a ``StopAsyncIteration``
    instance, so raising it ends the ``async for`` loop.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._finished = False
        self._queue = asyncio.Queue()

    def put(self, item: Union[RequestOutput, Exception]) -> None:
        """Enqueue ``item``. Items put after :meth:`finish` are dropped."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(self) -> None:
        """Append the end-of-stream marker and mark the stream finished."""
        end_marker = StopAsyncIteration()
        self._queue.put_nowait(end_marker)
        self._finished = True

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has been called on this stream."""
        return self._finished

    def __aiter__(self):
        return self

    async def __anext__(self) -> RequestOutput:
        item = await self._queue.get()
        if not isinstance(item, Exception):
            return item
        raise item


80
81
82
83
84
85
86
87
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    New requests are queued by :meth:`add_request`. Their streams are only
    registered in ``_request_streams`` once the background loop drains the
    queue through :meth:`get_new_and_finished_requests`. Finished or aborted
    request ids are queued the same way and handed to the loop on its next
    iteration.
    """

    def __init__(self) -> None:
        # Streams of requests that have been handed to the engine.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Ids of finished/aborted requests awaiting cleanup by the loop.
        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, engine add_request kwargs) pairs not yet sent to the engine.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Ids currently queued in ``_new_requests``. Their streams are not in
        # ``_request_streams`` yet. This set lets ``add_request`` reject a
        # second request with the same id during that window. Otherwise the
        # later stream would overwrite the earlier one and leave its consumer
        # waiting forever.
        self._pending_request_ids: Set[str] = set()
        # Set whenever a request is queued; awaited by wait_for_new_requests.
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams and abort them.

        Args:
            exc: The exception to deliver to the consumers.
            request_id: Target a single registered request. If None, every
                registered stream receives ``exc``.
        """
        if request_id is not None:
            self._request_streams[request_id].put(exc)
            self.abort_request(request_id)
        else:
            # abort_request does not mutate _request_streams, so iterating
            # the live dict here is safe.
            for rid, stream in self._request_streams.items():
                stream.put(exc)
                self.abort_request(rid)

    def process_request_output(self,
                               request_output: RequestOutput,
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        The output is delivered to its stream. A finished output also
        schedules the request for cleanup via :meth:`abort_request`.
        """
        request_id = request_output.request_id

        self._request_streams[request_id].put(request_output)
        if request_output.finished:
            if verbose:
                logger.info(f"Finished request {request_id}.")
            self.abort_request(request_id)

    def process_exception(self,
                          request_id: str,
                          exception: Exception,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine to one request's stream
        and schedule that request for cleanup."""
        self._request_streams[request_id].put(exception)
        if verbose:
            logger.info(f"Finished request {request_id}.")
        self.abort_request(request_id)

    def add_request(self, request_id: str,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Args:
            request_id: Unique id of the request.
            **engine_add_request_kwargs: Forwarded, together with
                ``request_id``, to the engine's add_request call.

        Returns:
            The stream on which the request's outputs will be delivered.

        Raises:
            KeyError: If ``request_id`` is already tracked, either as a
                registered request or as one still waiting in the queue.
        """
        if (request_id in self._request_streams
                or request_id in self._pending_request_ids):
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))
        self._pending_request_ids.add(request_id)

        self.new_requests_event.set()

        return stream

    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
        """Abort a request during next background loop iteration.

        The id is always queued so the loop tells the engine to abort it. The
        stream, if registered and not yet finished, is finished immediately.
        """
        if verbose:
            logger.info(f"Aborted request {request_id}.")

        self._finished_requests.put_nowait(request_id)

        if request_id not in self._request_streams or self._request_streams[
                request_id].finished:
            # The request has already finished or been aborted.
            return

        self._request_streams[request_id].finish()

    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        Finished ids are drained first and their streams unregistered. Queued
        requests are then drained. Those whose id was already aborted get
        their stream finished and are skipped. The rest are registered and
        returned.

        Returns:
            A ``(new_requests, finished_requests)`` pair. ``new_requests``
            holds the kwargs dicts to pass to the engine's add_request.
            ``finished_requests`` holds the ids to abort in the engine.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._finished_requests.empty():
            request_id = self._finished_requests.get_nowait()
            finished_requests.add(request_id)
            self._request_streams.pop(request_id, None)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            # The id leaves the pending state here whether or not it was
            # aborted in the meantime.
            self._pending_request_ids.discard(stream.request_id)
            if stream.request_id in finished_requests:
                # The request has already been aborted.
                stream.finish()
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Wait for the new-request event unless requests are already queued,
        then reset the event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Whether any request is queued but not yet handed to the loop."""
        return not self._new_requests.empty()
193

Antoni Baum's avatar
Antoni Baum committed
194
195
196
197
198
199
200
201
202
203
204
205
206
207

class _AsyncLLMEngine(LLMEngine):
    """LLMEngine subclass that exposes coroutine versions of its entry points."""

    async def step_async(self) -> List[RequestOutput]:
        """Run one decoding iteration and return the newly produced outputs.

        The scheduler picks the sequence groups to run next, together with
        the blocks to swap in, swap out and copy. If anything was scheduled,
        the model executor is awaited on that batch. The executor output and
        the scheduler outputs are then passed to ``_process_model_outputs``,
        whose result is returned. Workers run asynchronously when the
        executor supports it.
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

        if scheduler_outputs.is_empty():
            # Nothing was scheduled, so there is no model call to make.
            output = []
        else:
            output = await self.model_executor.execute_model_async(
                seq_group_metadata_list,
                scheduler_outputs.blocks_to_swap_in,
                scheduler_outputs.blocks_to_swap_out,
                scheduler_outputs.blocks_to_copy,
            )

        return self._process_model_outputs(output, scheduler_outputs)

    async def encode_request_async(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]] = None,
        lora_request: Optional[LoRARequest] = None,
    ):
        """Return the prompt's token ids, tokenizing ``prompt`` if needed.

        If ``prompt_token_ids`` is supplied, it is returned untouched.
        Otherwise ``prompt`` must be non-None and is encoded through the
        tokenizer's ``encode_async`` for this request and LoRA adapter.
        """
        if prompt_token_ids is not None:
            return prompt_token_ids
        assert prompt is not None
        return await self.tokenizer.encode_async(
            request_id=request_id,
            prompt=prompt,
            lora_request=lora_request)

    async def add_request_async(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        """Async counterpart of ``add_request``.

        Steps:
            - Validate the LoRA request.
            - Default the arrival time to now.
            - Obtain token ids via ``encode_request_async``.
            - Delegate to the synchronous ``add_request``.

        Raises:
            ValueError: If a LoRA request is given while LoRA is not enabled.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        arrival_time = time.time() if arrival_time is None else arrival_time
        token_ids = await self.encode_request_async(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            lora_request=lora_request)

        return self.add_request(request_id,
                                prompt=prompt,
                                prompt_token_ids=token_ids,
                                sampling_params=sampling_params,
                                arrival_time=arrival_time,
                                lora_request=lora_request,
                                multi_modal_data=multi_modal_data)

    async def check_health_async(self) -> None:
        """Coroutine wrapper around the executor's synchronous health check."""
        self.model_executor.check_health()
267

268

269
270
class AsyncLLMEngine:
    """An asynchronous wrapper for LLMEngine.

    This class is used to wrap the LLMEngine class to make it asynchronous. It
    uses asyncio to create a background loop that keeps processing incoming
    requests. The LLMEngine is kicked by the generate method when there
    are requests in the waiting queue. The generate method yields the outputs
    from the LLMEngine to the caller.

    NOTE: For the comprehensive list of arguments, see `LLMEngine`.

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
            async frontend will be executed in a separate process as the
            model workers.
        log_requests: Whether to log the requests.
        max_log_len: Maximum number of prompt characters or prompt ID numbers
            being printed in log.
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
        *args: Arguments for LLMEngine.
        **kwargs: Arguments for LLMEngine.
    """

    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 max_log_len: Optional[int] = None,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
        self.max_log_len = max_log_len
        self.engine = self._init_engine(*args, **kwargs)

        self.background_loop = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded = None
        self.start_engine_loop = start_engine_loop
        # Created in start_background_loop so it binds to the running loop.
        self._request_tracker: Optional[RequestTracker] = None
        self._errored_with: Optional[BaseException] = None

    @classmethod
    def from_engine_args(cls,
                         engine_args: AsyncEngineArgs,
                         start_engine_loop: bool = True) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_configs = engine_args.create_engine_configs()
        parallel_config = engine_configs[2]
        device_config = engine_configs[4]

        if device_config.device_type == "neuron":
            raise NotImplementedError("Neuron is not supported for "
                                      "async engine yet.")
        elif parallel_config.worker_use_ray or engine_args.engine_use_ray:
            initialize_ray_cluster(parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            executor_class = RayGPUExecutorAsync
        else:
            assert parallel_config.world_size == 1, (
                "Ray is required if parallel_config.world_size > 1.")
            from vllm.executor.gpu_executor import GPUExecutorAsync
            executor_class = GPUExecutorAsync
        # Create the async LLM engine.
        engine = cls(parallel_config.worker_use_ray,
                     engine_args.engine_use_ray,
                     *engine_configs,
                     executor_class,
                     log_requests=not engine_args.disable_log_requests,
                     log_stats=not engine_args.disable_log_stats,
                     max_log_len=engine_args.max_log_len,
                     start_engine_loop=start_engine_loop)
        return engine

    @property
    def is_running(self) -> bool:
        """True while the background loop task exists and has not completed."""
        return (self.background_loop is not None
                and not self._background_loop_unshielded.done())

    @property
    def is_stopped(self) -> bool:
        """True if the engine has errored or its background loop completed."""
        return self.errored or (self.background_loop is not None
                                and self._background_loop_unshielded.done())

    @property
    def errored(self) -> bool:
        """True once an error has been recorded with :meth:`set_errored`."""
        return self._errored_with is not None

    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` as the reason the engine is no longer usable."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Record ``exc`` and deliver it to every registered request stream.

        Installed as the error callback of the background loop's
        done-callback in :meth:`start_background_loop`.
        """
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)

    async def get_tokenizer(self) -> "PreTrainedTokenizer":
        """Return the engine's tokenizer, via the Ray actor if one is used."""
        if self.engine_use_ray:
            return await self.engine.get_tokenizer.remote()
        else:
            return self.engine.get_tokenizer()

    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the background loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        self._background_loop_unshielded = asyncio.get_event_loop(
        ).create_task(self.run_engine_loop())
        self._background_loop_unshielded.add_done_callback(
            partial(_raise_exception_on_finish,
                    error_callback=self._error_callback))
        self.background_loop = asyncio.shield(self._background_loop_unshielded)

    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
        """Instantiate the underlying engine, locally or as a Ray actor.

        The engine class is chosen as follows:
            - Without ``engine_use_ray``: the plain engine class.
            - With ``worker_use_ray``: an actor with ``num_cpus=0``.
            - Otherwise: an actor that reserves GPUs itself. It takes
              ``gpu_memory_utilization`` of one GPU when tensor parallel
              size is 1, else one full GPU.
        """
        if not self.engine_use_ray:
            engine_class = self._engine_class
        elif self.worker_use_ray:
            engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
        else:
            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
            # order of the arguments.
            cache_config = args[1]
            parallel_config = args[2]
            if parallel_config.tensor_parallel_size == 1:
                num_gpus = cache_config.gpu_memory_utilization
            else:
                num_gpus = 1
            engine_class = ray.remote(num_gpus=num_gpus)(
                self._engine_class).remote
        return engine_class(*args, **kwargs)

    async def engine_step(self) -> bool:
        """Kick the engine to process the waiting requests.

        One step does the following:
            - Forward newly queued requests to the engine. Requests rejected
              with ValueError get the error on their stream.
            - Abort finished or aborted ids in the engine.
            - Run one engine step and route its outputs to their streams.

        Returns True if the step produced any request outputs, which the
        loop uses as its signal that requests are still in progress."""

        new_requests, finished_requests = (
            self._request_tracker.get_new_and_finished_requests())

        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            # TODO: Maybe add add_request_batch to reduce Ray overhead
            try:
                if self.engine_use_ray:
                    await self.engine.add_request.remote(**new_request)
                else:
                    await self.engine.add_request_async(**new_request)
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    new_request["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if finished_requests:
            await self._engine_abort(finished_requests)

        if self.engine_use_ray:
            request_outputs = await self.engine.step.remote()
        else:
            request_outputs = await self.engine.step_async()

        # Put the outputs into the corresponding streams.
        for request_output in request_outputs:
            self._request_tracker.process_request_output(
                request_output, verbose=self.log_requests)

        return len(request_outputs) > 0

    async def _engine_abort(self, request_ids: Iterable[str]):
        """Abort ``request_ids`` in the engine, via the Ray actor if used."""
        if self.engine_use_ray:
            await self.engine.abort_request.remote(request_ids)
        else:
            self.engine.abort_request(request_ids)

    async def run_engine_loop(self):
        """Drive :meth:`engine_step` forever.

        When the previous step produced no outputs, the loop waits for new
        requests before stepping again. Each step is bounded by
        ENGINE_ITERATION_TIMEOUT_S. On timeout the engine is marked errored
        and the TimeoutError is re-raised, which ends this task.
        """
        has_requests_in_progress = False
        while True:
            if not has_requests_in_progress:
                logger.debug("Waiting for new requests...")
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                has_requests_in_progress = await asyncio.wait_for(
                    self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise
            # Yield control so other coroutines (e.g. request producers and
            # consumers) can run between steps.
            await asyncio.sleep(0)

    async def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> AsyncStream:
        """Register a request with the tracker and return its output stream.

        Steps:
            - Log the (possibly truncated) request when logging is enabled.
            - If the background loop is not running, start it when
              ``start_engine_loop`` is set; otherwise raise.
            - Tokenize the prompt when no token ids are given.
            - Queue the request for the next loop iteration.

        Raises:
            AsyncEngineDeadError: If the loop is not running and may not be
                started, or if it has already errored.
            KeyError: From the request tracker if ``request_id`` is already
                tracked.
        """
        if self.log_requests:
            shortened_prompt = prompt
            shortened_token_ids = prompt_token_ids
            if self.max_log_len is not None:
                if shortened_prompt is not None:
                    shortened_prompt = shortened_prompt[:self.max_log_len]
                if shortened_token_ids is not None:
                    shortened_token_ids = shortened_token_ids[:self.
                                                              max_log_len]
            logger.info(f"Received request {request_id}: "
                        f"prompt: {shortened_prompt!r}, "
                        f"sampling_params: {sampling_params}, "
                        f"prompt_token_ids: {shortened_token_ids}, "
                        f"lora_request: {lora_request}.")

        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        if arrival_time is None:
            arrival_time = time.time()

        if self.engine_use_ray:
            prompt_token_ids = await self.engine.encode_request_async.remote(
                request_id=request_id,
                prompt=prompt,
                prompt_token_ids=prompt_token_ids,
                lora_request=lora_request)
        else:
            prompt_token_ids = await self.engine.encode_request_async(
                request_id=request_id,
                prompt=prompt,
                prompt_token_ids=prompt_token_ids,
                lora_request=lora_request)

        stream = self._request_tracker.add_request(
            request_id,
            prompt=prompt,
            sampling_params=sampling_params,
            prompt_token_ids=prompt_token_ids,
            arrival_time=arrival_time,
            lora_request=lora_request,
            multi_modal_data=multi_modal_data,
        )

        return stream

    async def generate(
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        request_id: str,
        prompt_token_ids: Optional[List[int]] = None,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None
    ) -> AsyncIterator[RequestOutput]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.

        Args:
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.
            lora_request: LoRA request to use for generation, if any.
            multi_modal_data: Multi modal data per request.

        Yields:
            The output `RequestOutput` objects from the LLMEngine for the
            request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.
            - If streaming raises, is cancelled, or the generator is closed
              early, the request is aborted before the error propagates.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": 0,
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(request_id)
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        # Preprocess the request.
        arrival_time = time.time()

        # Registration is deliberately outside the ``try``. If it fails (a
        # duplicate request id, a dead or never-started background loop, or
        # cancellation while tokenizing), this call owns nothing to abort.
        # Aborting ``request_id`` could cancel an unrelated in-flight request
        # that holds the same id. With no request tracker, it would raise
        # AttributeError and mask the real error.
        stream = await self.add_request(
            request_id,
            prompt,
            sampling_params,
            prompt_token_ids=prompt_token_ids,
            arrival_time=arrival_time,
            lora_request=lora_request,
            multi_modal_data=multi_modal_data,
        )

        try:
            async for request_output in stream:
                yield request_output
        except (Exception, asyncio.CancelledError, GeneratorExit):
            # On an error, a cancellation, or the consumer closing this
            # generator early (GeneratorExit), abort the request so the
            # engine stops working on it, then propagate.
            self._abort(request_id)
            raise

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if not self.is_running:
            raise AsyncEngineDeadError(
                "Background loop is not running. If it was running, "
                "inspect the output to find the stacktrace of the "
                "error that caused the background loop to stop "
                "(AsyncEngineDeadError).")

        return self._abort(request_id)

    def _abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
        self._request_tracker.abort_request(request_id,
                                            verbose=self.log_requests)

    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        if self.engine_use_ray:
            return await self.engine.get_model_config.remote()
        else:
            return self.engine.get_model_config()

    async def do_log_stats(self) -> None:
        """Ask the engine to log its stats, via the Ray actor if used."""
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote()
        else:
            self.engine.do_log_stats()

    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy.

        Raises:
            AsyncEngineDeadError: If the background loop is stopped.
            RuntimeError: If the Ray engine actor is dead.
        """
        t = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        if self.engine_use_ray:
            try:
                await self.engine.check_health.remote()
            except ray.exceptions.RayActorError as e:
                raise RuntimeError("Engine is dead.") from e
        else:
            await self.engine.check_health_async()
        logger.debug(f"Health check took {time.perf_counter()-t}s")