async_llm_engine.py 26.6 KB
Newer Older
1
import asyncio
2
import os
3
import time
Antoni Baum's avatar
Antoni Baum committed
4
from functools import partial
5
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
6
                    Union, AsyncIterator, Callable)
7

8
from vllm.lora.request import LoRARequest
9
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
10
11
12
13
14
15
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_cluster, ray
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
16
17

# Module-level logger used by the classes and functions defined in this file.
logger = init_logger(__name__)
18
19
# Maximum time, in seconds, a single engine iteration may take before the
# background loop gives up. Read once at import time from the
# VLLM_ENGINE_ITERATION_TIMEOUT_S environment variable (default: 60).
ENGINE_ITERATION_TIMEOUT_S = int(
    os.getenv("VLLM_ENGINE_ITERATION_TIMEOUT_S", default="60"))
20

Antoni Baum's avatar
Antoni Baum committed
21

22
23
24
25
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine's background loop is not running, has stopped,
    or has failed."""


26
27
28
def _raise_exception_on_finish(
        task: asyncio.Task, error_callback: Callable[[Exception],
                                                     None]) -> None:
    """Done-callback for the engine's background task.

    The background task is expected to run forever, so finishing with an
    exception *or* finishing normally is reported as a failure. Cancellation
    is the one exception: it is logged and otherwise ignored.

    Args:
        task: The finished background task.
        error_callback: Invoked with the exception that ended the task,
            before ``AsyncEngineDeadError`` is raised.

    Raises:
        AsyncEngineDeadError: If the task raised an exception or returned
            normally.
    """
    msg = ("Task finished unexpectedly. This should never happen! "
           "Please open an issue on Github.")

    try:
        task.result()
        # NOTE: This will be thrown if task exits normally (which it should not)
        raise AsyncEngineDeadError(msg)
    except asyncio.CancelledError:
        # On Python >= 3.8 CancelledError derives from BaseException, so the
        # ``except Exception`` clause below would not see it and it would
        # escape from this callback as an unhandled error. This clause must
        # stay first so that it also wins on older interpreters where
        # CancelledError is an Exception subclass.
        logger.info("Engine background task was cancelled.")
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        raise AsyncEngineDeadError(
            msg + " See stack trace above for the actual cause.") from e
43
44


Antoni Baum's avatar
Antoni Baum committed
45
46
47
48
49
50
51
52
53
class AsyncStream:
    """Asynchronously iterable stream of the RequestOutputs produced for a
    single request.

    Items are buffered in an ``asyncio.Queue``. Exceptions placed on the
    stream are re-raised to the consumer; ``finish()`` enqueues a
    ``StopAsyncIteration`` so that iteration terminates.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, Exception]) -> None:
        """Enqueue an output or an exception; ignored once finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(self) -> None:
        """Enqueue the end-of-stream marker and mark the stream finished."""
        self._queue.put_nowait(StopAsyncIteration())
        self._finished = True

    @property
    def finished(self) -> bool:
        """Whether ``finish()`` has been called."""
        return self._finished

    def __aiter__(self):
        return self

    async def __anext__(self) -> RequestOutput:
        """Wait for the next queued item; raise it if it is an exception."""
        item = await self._queue.get()
        if not isinstance(item, Exception):
            return item
        raise item


77
78
79
80
81
82
83
84
class RequestTracker:
    """Synchronous bookkeeping for requests.

    Keeps the stream of every known request, plus two queues (new requests
    and finished/aborted request ids) that are drained by
    ``get_new_and_finished_requests``.
    """

    def __init__(self) -> None:
        self._request_streams: Dict[str, AsyncStream] = {}
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
        # Set whenever a request is queued; awaited by wait_for_new_requests.
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Deliver ``exc`` to one request stream, or to every tracked stream
        when ``request_id`` is None, and abort the affected request(s)."""
        if request_id is None:
            for rid, stream in self._request_streams.items():
                stream.put(exc)
                self.abort_request(rid)
            return

        self._request_streams[request_id].put(exc)
        self.abort_request(request_id)

    def process_request_output(self,
                               request_output: RequestOutput,
                               *,
                               verbose: bool = False) -> None:
        """Route an engine output to the stream of its request and, if the
        output is marked finished, close that request."""
        request_id = request_output.request_id
        self._request_streams[request_id].put(request_output)

        if not request_output.finished:
            return
        if verbose:
            logger.info(f"Finished request {request_id}.")
        self.abort_request(request_id)

    def process_exception(self,
                          request_id: str,
                          exception: Exception,
                          *,
                          verbose: bool = False) -> None:
        """Deliver an engine exception to the request's stream and close the
        request."""
        stream = self._request_streams[request_id]
        stream.put(exception)
        if verbose:
            logger.info(f"Finished request {request_id}.")
        self.abort_request(request_id)

    def add_request(self, request_id: str,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Queue a request for the next background loop iteration and return
        the stream its outputs will arrive on.

        Raises:
            KeyError: If ``request_id`` is already being tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id)
        engine_kwargs = {"request_id": request_id, **engine_add_request_kwargs}
        self._new_requests.put_nowait((stream, engine_kwargs))

        # Wake up anyone blocked in wait_for_new_requests().
        self.new_requests_event.set()

        return stream

    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
        """Mark a request as finished; it is removed during the next call to
        ``get_new_and_finished_requests``."""
        if verbose:
            logger.info(f"Aborted request {request_id}.")

        self._finished_requests.put_nowait(request_id)

        stream = self._request_streams.get(request_id)
        if stream is None or stream.finished:
            # Unknown, already finished or already aborted: nothing to close.
            return

        stream.finish()

    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Drain both queues.

        Returns:
            A tuple ``(new_requests, finished_requests)``: the keyword
            arguments of each newly accepted request, and the ids of the
            requests that finished or were aborted.
        """
        finished_requests: Set[str] = set()
        while not self._finished_requests.empty():
            finished_id = self._finished_requests.get_nowait()
            finished_requests.add(finished_id)
            self._request_streams.pop(finished_id, None)

        new_requests: List[Dict] = []
        while not self._new_requests.empty():
            stream, engine_kwargs = self._new_requests.get_nowait()
            if stream.request_id in finished_requests:
                # Aborted before it ever reached the engine.
                stream.finish()
            else:
                self._request_streams[stream.request_id] = stream
                new_requests.append(engine_kwargs)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Block until at least one new request is queued, then reset the
        notification event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Whether the new-request queue is non-empty."""
        return not self._new_requests.empty()
190

Antoni Baum's avatar
Antoni Baum committed
191
192
193
194
195
196
197
198
199
200
201
202
203
204

class _AsyncLLMEngine(LLMEngine):
    """LLMEngine subclass that adds coroutine variants of its methods."""

    async def step_async(self) -> List[RequestOutput]:
        """Run one decoding iteration and return the newly generated results.

        The scheduler first decides which sequences run in this iteration and
        which token blocks are swapped in, swapped out or copied. Unless the
        schedule is empty, the model is then executed on the workers
        (asynchronously where possible). Finally the model output is handed
        to ``_process_model_outputs`` together with the scheduler outputs.
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

        if scheduler_outputs.is_empty():
            output = []
        else:
            # Execute the model.
            driver_kwargs = {
                "seq_group_metadata_list": seq_group_metadata_list,
                "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
                "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
                "blocks_to_copy": scheduler_outputs.blocks_to_copy,
            }
            all_outputs = await self._run_workers_async(
                "execute_model", driver_kwargs=driver_kwargs)

            # Only the driver worker returns the sampling results.
            output = all_outputs[0]

        return self._process_model_outputs(output, scheduler_outputs)

    async def encode_request_async(
        self,
        request_id: str,  # pylint: disable=unused-argument
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]] = None,
        lora_request: Optional[LoRARequest] = None,
    ):
        """Return the prompt token ids, tokenizing ``prompt`` only when no
        ids were supplied."""
        if prompt_token_ids is not None:
            return prompt_token_ids

        assert prompt is not None
        return await self.tokenizer.encode_async(request_id=request_id,
                                                 prompt=prompt,
                                                 lora_request=lora_request)

    async def add_request_async(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        """Tokenize the prompt if needed, then register the request through
        ``add_request``.

        Raises:
            ValueError: If a LoRA request is given while LoRA is not enabled.
        """
        lora_enabled = bool(self.lora_config)
        if lora_request is not None and not lora_enabled:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

        # Stamp the arrival before the (awaited) tokenization step.
        if arrival_time is None:
            arrival_time = time.time()

        token_ids = await self.encode_request_async(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            lora_request=lora_request)

        return self.add_request(
            request_id,
            prompt=prompt,
            prompt_token_ids=token_ids,
            sampling_params=sampling_params,
            arrival_time=arrival_time,
            lora_request=lora_request,
        )

    async def _run_workers_async(
        self,
        method: str,
        *args,
        driver_args: Optional[List[Any]] = None,
        driver_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Any:
        """Run ``method`` on the driver worker and on every ray worker.

        The driver worker receives ``driver_args`` / ``driver_kwargs`` (which
        default to ``args`` / ``kwargs``); the ray workers always receive
        ``args`` / ``kwargs``. The returned list holds the driver's result
        first, followed by the ray workers' results.
        """
        if driver_args is None:
            driver_args = args
        if driver_kwargs is None:
            driver_kwargs = kwargs

        # Run the driver worker in the event loop's default executor.
        driver_executor = getattr(self.driver_worker, method)
        driver_future = asyncio.get_event_loop().run_in_executor(
            None, partial(driver_executor, *driver_args, **driver_kwargs))

        # Run the ray workers asynchronously.
        worker_futures = [
            worker.execute_method.remote(method, *args, **kwargs)
            for worker in self.workers
        ]

        return await asyncio.gather(driver_future, *worker_futures)

    async def check_health_async(self):
        """Raises an error if engine is unhealthy."""
        self._check_if_any_actor_is_dead()

301

302
303
class AsyncLLMEngine:
    """An asynchronous wrapper for LLMEngine.
304

305
    This class is used to wrap the LLMEngine class to make it asynchronous. It
306
    uses asyncio to create a background loop that keeps processing incoming
307
    requests. The LLMEngine is kicked by the generate method when there
308
    are requests in the waiting queue. The generate method yields the outputs
309
    from the LLMEngine to the caller.
310

311
    NOTE: For the comprehensive list of arguments, see `LLMEngine`.
312
313
314
315
316

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
Zhuohan Li's avatar
Zhuohan Li committed
317
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
318
319
            async frontend will be executed in a separate process as the
            model workers.
320
        log_requests: Whether to log the requests.
zspo's avatar
zspo committed
321
322
        max_log_len: Maximum number of prompt characters or prompt ID numbers
            being printed in log.
323
324
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
325
326
        *args: Arguments for LLMEngine.
        **kwargs: Arguments for LLMEngine.
327
    """
328

Antoni Baum's avatar
Antoni Baum committed
329
330
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

331
332
333
334
335
    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 max_log_len: Optional[int] = None,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Store the frontend options and build the wrapped engine.

        ``*args`` and ``**kwargs`` are forwarded untouched to
        ``_init_engine``. The background loop itself is not started here.
        """
        # Frontend options; _init_engine reads the two *_use_ray flags.
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
        self.max_log_len = max_log_len

        self.engine = self._init_engine(*args, **kwargs)

        # Background loop state, filled in by start_background_loop().
        self.background_loop = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded = None
        self.start_engine_loop = start_engine_loop

        self._request_tracker: Optional[RequestTracker] = None
        self._errored_with: Optional[BaseException] = None
Antoni Baum's avatar
Antoni Baum committed
353

354
355
    @property
    def is_running(self) -> bool:
356
        return (self.background_loop is not None
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
                and not self._background_loop_unshielded.done())

    @property
    def is_stopped(self) -> bool:
        return self.errored or (self.background_loop is not None
                                and self._background_loop_unshielded.done())

    @property
    def errored(self) -> bool:
        return self._errored_with is not None

    def set_errored(self, exc: Exception) -> None:
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)
374

375
376
377
    def get_tokenizer(self):
        return self.engine.tokenizer.tokenizer

378
    def start_background_loop(self) -> None:
        """Create the request tracker and schedule ``run_engine_loop`` as a
        task on the current event loop.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the background loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        engine_task = asyncio.get_event_loop().create_task(
            self.run_engine_loop())
        self._background_loop_unshielded = engine_task
        # Report any termination of the task through _error_callback.
        on_done = partial(_raise_exception_on_finish,
                          error_callback=self._error_callback)
        engine_task.add_done_callback(on_done)
        self.background_loop = asyncio.shield(engine_task)
Antoni Baum's avatar
Antoni Baum committed
394
395
396

    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
Zhuohan Li's avatar
Zhuohan Li committed
397
        if not self.engine_use_ray:
Antoni Baum's avatar
Antoni Baum committed
398
            engine_class = self._engine_class
399
        elif self.worker_use_ray:
Antoni Baum's avatar
Antoni Baum committed
400
            engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
401
        else:
Woosuk Kwon's avatar
Woosuk Kwon committed
402
403
404
405
406
407
408
409
410
411
            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
            # order of the arguments.
            cache_config = args[1]
            parallel_config = args[2]
            if parallel_config.tensor_parallel_size == 1:
                num_gpus = cache_config.gpu_memory_utilization
            else:
                num_gpus = 1
            engine_class = ray.remote(num_gpus=num_gpus)(
                self._engine_class).remote
Antoni Baum's avatar
Antoni Baum committed
412
413
        return engine_class(*args, **kwargs)

414
415
416
417
    async def engine_step(self) -> bool:
        """Kick the engine to process the waiting requests.

        Drains the request tracker, forwards new requests and aborts to the
        engine, runs one engine step and routes every output to its stream.

        Returns True if the step produced at least one request output, which
        the background loop treats as "requests still in progress"."""

        new_requests, finished_requests = (
            self._request_tracker.get_new_and_finished_requests())

        # Add the requests into the vLLM engine's waiting queue.
        # TODO: Maybe add add_request_batch to reduce Ray overhead
        for engine_kwargs in new_requests:
            try:
                if self.engine_use_ray:
                    await self.engine.add_request.remote(**engine_kwargs)
                else:
                    await self.engine.add_request_async(**engine_kwargs)
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    engine_kwargs["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if finished_requests:
            await self._engine_abort(finished_requests)

        if self.engine_use_ray:
            step_call = self.engine.step.remote()
        else:
            step_call = self.engine.step_async()
        request_outputs = await step_call

        # Put the outputs into the corresponding streams.
        for request_output in request_outputs:
            self._request_tracker.process_request_output(
                request_output, verbose=self.log_requests)

        return len(request_outputs) > 0

Antoni Baum's avatar
Antoni Baum committed
453
454
455
456
457
458
459
    async def _engine_abort(self, request_ids: Iterable[str]):
        if self.engine_use_ray:
            await self.engine.abort_request.remote(request_ids)
        else:
            self.engine.abort_request(request_ids)

    async def run_engine_loop(self):
        """Run engine steps forever.

        While the previous step reported no requests in progress, the loop
        waits for a new request before stepping again. A step that exceeds
        ``ENGINE_ITERATION_TIMEOUT_S`` marks the engine as errored and
        re-raises the ``asyncio.TimeoutError``.
        """
        busy = False
        while True:
            if not busy:
                logger.debug("Waiting for new requests...")
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                busy = await asyncio.wait_for(self.engine_step(),
                                              ENGINE_ITERATION_TIMEOUT_S)
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise

            # Hand control back to the event loop between iterations.
            await asyncio.sleep(0)

    async def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> AsyncStream:
        """Tokenize the prompt if needed and queue the request on the request
        tracker.

        Starts the background loop first when it is not running and
        ``start_engine_loop`` is set.

        Returns:
            The ``AsyncStream`` on which the request's outputs will arrive.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                may not be started automatically.
        """
        if self.log_requests:
            # Optionally truncate what gets logged to max_log_len items.
            log_limit = self.max_log_len
            shortened_prompt = prompt
            shortened_token_ids = prompt_token_ids
            if log_limit is not None:
                if prompt is not None:
                    shortened_prompt = prompt[:log_limit]
                if prompt_token_ids is not None:
                    shortened_token_ids = prompt_token_ids[:log_limit]
            logger.info(f"Received request {request_id}: "
                        f"prompt: {shortened_prompt!r}, "
                        f"sampling_params: {sampling_params}, "
                        f"prompt_token_ids: {shortened_token_ids}, "
                        f"lora_request: {lora_request}.")

        if not self.is_running:
            if not self.start_engine_loop:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")
            self.start_background_loop()

        if arrival_time is None:
            arrival_time = time.time()

        encode_kwargs = dict(request_id=request_id,
                             prompt=prompt,
                             prompt_token_ids=prompt_token_ids,
                             lora_request=lora_request)
        if self.engine_use_ray:
            prompt_token_ids = await self.engine.encode_request_async.remote(
                **encode_kwargs)
        else:
            prompt_token_ids = await self.engine.encode_request_async(
                **encode_kwargs)

        return self._request_tracker.add_request(
            request_id,
            prompt=prompt,
            sampling_params=sampling_params,
            prompt_token_ids=prompt_token_ids,
            arrival_time=arrival_time,
            lora_request=lora_request)
538

539
    async def generate(
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        request_id: str,
        prompt_token_ids: Optional[List[int]] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> AsyncIterator[RequestOutput]:
        """Generate outputs for a request.

        This is an async generator: it submits the request through
        ``add_request`` and then yields every output that arrives on the
        request's stream.

        Args:
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.
            lora_request: LoRA request to use for generation, if any.

        Yields:
            The output `RequestOutput` objects from the LLMEngine for the
            request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.
            - If anything fails, or the generator is cancelled, the request
              is aborted before the exception is re-raised.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        # Preprocess the request.
        # This should not be used for logging, as it is monotonic time.
        # NOTE(review): add_request() falls back to time.time() when it is
        # given no arrival_time, so the two entry points use different
        # clocks - confirm which one the engine expects.
        arrival_time = time.monotonic()

        try:
            stream = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                prompt_token_ids=prompt_token_ids,
                arrival_time=arrival_time,
                lora_request=lora_request,
            )

            async for request_output in stream:
                yield request_output
        except (Exception, asyncio.CancelledError) as exc:
            # Whatever went wrong (including cancellation of this coroutine),
            # make sure the request does not linger in the engine.
            self._abort(request_id)
            raise exc
630

Antoni Baum's avatar
Antoni Baum committed
631
632
    async def abort(self, request_id: str) -> None:
        """Abort a request.
633

Antoni Baum's avatar
Antoni Baum committed
634
635
        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.
636

Antoni Baum's avatar
Antoni Baum committed
637
638
639
        Args:
            request_id: The unique id of the request.
        """
640
641
642
643
644
645
646
        if not self.is_running:
            raise AsyncEngineDeadError(
                "Background loop is not running. If it was running, "
                "inspect the output to find the stacktrace of the "
                "error that caused the background loop to stop "
                "(AsyncEngineDeadError).")

Antoni Baum's avatar
Antoni Baum committed
647
        return self._abort(request_id)
648

Antoni Baum's avatar
Antoni Baum committed
649
    def _abort(self, request_id: str) -> None:
650
651
652
653
654
655
656
657
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
658
659
        self._request_tracker.abort_request(request_id,
                                            verbose=self.log_requests)
660

661
662
663
664
665
666
667
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        if self.engine_use_ray:
            return await self.engine.get_model_config.remote()
        else:
            return self.engine.get_model_config()

Zhuohan Li's avatar
Zhuohan Li committed
668
    @classmethod
    def from_engine_args(cls,
                         engine_args: AsyncEngineArgs,
                         start_engine_loop: bool = True) -> "AsyncLLMEngine":
        """Build an async LLM engine from parsed engine arguments.

        Creates the engine configs, initializes the cluster for them and
        passes everything on to the class constructor.
        """
        # Create the engine configs.
        engine_configs = engine_args.create_engine_configs()
        # The parallel config is the third entry of the config tuple.
        parallel_config = engine_configs[2]

        # Initialize the cluster.
        placement_group = initialize_cluster(parallel_config,
                                             engine_args.engine_use_ray)

        # Create the async LLM engine.
        return cls(parallel_config.worker_use_ray,
                   engine_args.engine_use_ray,
                   *engine_configs,
                   placement_group,
                   log_requests=not engine_args.disable_log_requests,
                   log_stats=not engine_args.disable_log_stats,
                   max_log_len=engine_args.max_log_len,
                   start_engine_loop=start_engine_loop)
689
690
691
692
693
694

    async def do_log_stats(self) -> None:
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote()
        else:
            self.engine.do_log_stats()
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710

    async def check_health(self):
        """Raise if the engine is unhealthy.

        Raises:
            AsyncEngineDeadError: If the background loop is stopped.
            RuntimeError: If the ray engine actor reports a
                ``RayActorError``.
        """
        t = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        if not self.engine_use_ray:
            await self.engine.check_health_async()
        else:
            try:
                await self.engine.check_health.remote()
            except ray.exceptions.RayActorError as e:
                raise RuntimeError("Engine is dead.") from e
        logger.debug(f"Health check took {time.perf_counter()-t}s")