async_llm_engine.py 45.8 KB
Newer Older
1
2
import asyncio
import time
Antoni Baum's avatar
Antoni Baum committed
3
from functools import partial
4
from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
5
                    Optional, Set, Tuple, Type, Union)
6

7
from transformers import PreTrainedTokenizer
8
from typing_extensions import assert_never
9

10
import vllm.envs as envs
11
12
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
13
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
14
from vllm.engine.arg_utils import AsyncEngineArgs
15
from vllm.engine.async_timeout import asyncio_timeout
16
17
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
                                    PromptComponents)
18
from vllm.engine.metrics import StatLoggerBase
19
from vllm.executor.executor_base import ExecutorAsyncBase
20
from vllm.executor.ray_utils import initialize_ray_cluster, ray
21
22
23
from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
                         SingletonPromptInputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
Woosuk Kwon's avatar
Woosuk Kwon committed
24
from vllm.logger import init_logger
25
from vllm.lora.request import LoRARequest
26
27
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
28
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
29
from vllm.sampling_params import SamplingParams
30
from vllm.sequence import ExecuteModelRequest, SamplerOutput
yhu422's avatar
yhu422 committed
31
from vllm.usage.usage_lib import UsageContext
32
from vllm.utils import print_warning_once
33
34

# Module-wide logger created through vLLM's logging helper.
logger = init_logger(__name__)
# Engine iteration timeout, read once at import time from `vllm.envs`.
# NOTE(review): the `_S` suffix suggests seconds, and it is presumably used
# with `asyncio_timeout` by the engine loop later in this file (not visible
# here) — confirm.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
36

Antoni Baum's avatar
Antoni Baum committed
37

38
39
40
41
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine's background loop has died (or has already
    errored) and the async engine can no longer serve requests."""


42
43
44
45
46
47
48
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback for the task running `engine.run_engine_loop()` only.

    That task runs a `while True` loop, so it can only end by being cancelled
    or by raising:

    * cancellation is treated as a graceful shutdown and just logged;
    * any exception (including an unexpected normal return, which is turned
      into an AssertionError) is logged, passed to ``error_callback`` and
      re-raised as :class:`AsyncEngineDeadError`.
    """
    try:
        unexpected_result = task.result()
        # Getting here means the loop returned normally. The design rules
        # that out, so route it through the generic failure handler below.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {unexpected_result}")
    except asyncio.exceptions.CancelledError:
        # We assume cancellation means a graceful shutdown, which should
        # only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as failure:
        logger.error("Engine background task failed", exc_info=failure)
        error_callback(failure)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from failure
68
69


70
71
72
# Sentinel queued by `AsyncStream.finish()` to mark a normal end-of-stream.
# It is compared by identity inside `AsyncStream.generator()`.
STOP_ITERATION = Exception()  # Sentinel


class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator.

    Producers call :meth:`put` for each output and :meth:`finish` once at the
    end, optionally with an exception (instance or class). :meth:`generator`
    drains the queue up to the terminal item queued by :meth:`finish`. Outputs
    queued before ``finish`` are therefore always delivered, and a terminal
    exception is always raised to the consumer.
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        # Called with the request id when the consumer abandons the generator.
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(
        self,
        item: "Union[RequestOutput, EmbeddingRequestOutput, Exception]",
    ) -> None:
        """Queue ``item`` for the consumer; dropped once the stream finished."""
        if self._finished:
            return
        self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Mark the stream finished (idempotent).

        Queues ``exception`` if given, otherwise the ``STOP_ITERATION``
        sentinel. Items queued earlier are still delivered before it.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if exception is not None else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        return self._finished

    @staticmethod
    def _is_raisable(value: object) -> bool:
        """True for exception instances and exception classes.

        Covers ``BaseException`` subclasses such as
        ``asyncio.CancelledError``, which callers pass as a class.
        """
        return isinstance(value, BaseException) or (
            isinstance(value, type) and issubclass(value, BaseException))

    async def generator(
        self
    ) -> "AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]":
        """Yield queued outputs until the terminal item from :meth:`finish`.

        Returns on the ``STOP_ITERATION`` sentinel and raises any other
        queued exception. If the consumer closes the generator early, the
        request is cancelled via the ``cancel`` callback and
        ``asyncio.CancelledError`` is raised.
        """
        try:
            # Do not stop on `self._finished`: finish() sets it before the
            # consumer has drained the queue. Every finish() enqueues a
            # terminal item, so this loop always ends.
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None
Antoni Baum's avatar
Antoni Baum committed
116
117


118
119
120
121
122
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    New requests and aborts are buffered in asyncio queues until the engine
    loop collects them with :meth:`get_new_and_aborted_requests`. A request
    handed to the engine keeps its :class:`AsyncStream` in
    ``_request_streams`` until it finishes or is aborted.
    """

    def __init__(self) -> None:
        # request_id -> stream, for requests already handed to the engine.
        self._request_streams: Dict[str, AsyncStream] = {}
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Ids queued in `_new_requests` that the engine loop has not collected
        # yet. Without this, two add_request() calls with the same id in one
        # loop iteration would both be accepted. The second stream would then
        # overwrite the first in `_request_streams`, leaving the first
        # caller's stream unfinished forever.
        self._pending_request_ids: Set[str] = set()
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item) -> bool:
        """Whether ``item`` is the id of a request already handed over."""
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None)."""
        if request_id is not None:
            self.abort_request(request_id, exception=exc)
        else:
            # NB: tuple() used here because self.abort_request pops the stream
            # out of self._request_streams, so we can't iterate on it directly
            for rid in tuple(self._request_streams.keys()):
                self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine."""
        request_id = request_output.request_id
        finished = request_output.finished

        if finished:
            stream = self._request_streams.pop(request_id, None)
        else:
            stream = self._request_streams.get(request_id)
        # Guard against a KeyError which can occur if the request was aborted
        # while the output was generated
        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if verbose and finished:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Raises:
            KeyError: if ``request_id`` is already being streamed, or is
                already waiting to be handed to the engine.
        """
        if (request_id in self._request_streams
                or request_id in self._pending_request_ids):
            raise KeyError(f"Request {request_id} already exists.")

        abort_request = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, abort_request)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))
        self._pending_request_ids.add(request_id)

        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Abort a request during next background loop iteration."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)
        # Aborting a still-pending request is resolved in
        # get_new_and_aborted_requests(), where its stream is cancelled.
        # Release the id now so the caller may re-submit it right away.
        self._pending_request_ids.discard(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine."""
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._aborted_requests.empty():
            request_id = self._aborted_requests.get_nowait()
            finished_requests.add(request_id)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            request_id = stream.request_id
            self._pending_request_ids.discard(request_id)
            if request_id in finished_requests:
                # The request has already been aborted.
                # NOTE(review): this also fires when a *running* request is
                # aborted and an id-identical request is re-submitted in the
                # same iteration. The new one is cancelled and the old one is
                # then not aborted in the engine — confirm intended.
                stream.finish(asyncio.CancelledError)
                finished_requests.discard(request_id)
            else:
                self._request_streams[request_id] = stream
                new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Wait until at least one new request is queued, then reset the
        wake-up event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        return not self._new_requests.empty()
250

Antoni Baum's avatar
Antoni Baum committed
251
252
253
254

class _AsyncLLMEngine(LLMEngine):
    """Extension of :class:`LLMEngine` with coroutine versions of the
    stepping and request-ingestion paths."""

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Perform one decoding iteration and return newly generated results.

        The workers are run asynchronously when possible.

        In order:
          1. the scheduler of ``virtual_engine`` picks the sequences to run
             and the token blocks to swap in/out/copy;
          2. if anything was scheduled, the model is executed with that plan
             (otherwise the output list is empty);
          3. the outputs are turned into request outputs, then stats are
             logged and tracing is emitted for this iteration.
        """
        scheduler = self.scheduler[virtual_engine]
        seq_group_metadata_list, scheduler_outputs = scheduler.schedule()

        output: List[SamplerOutput] = []
        if not scheduler_outputs.is_empty():
            finished_requests_ids = (
                scheduler.get_and_reset_finished_requests_ids())
            model_request = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids)
            output = await self.model_executor.execute_model_async(
                model_request)

        request_outputs = self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

        self.do_log_stats(scheduler_outputs, output)
        self.do_tracing(scheduler_outputs)

        return request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Ask the model executor to stop its remote worker execution loop."""
        executor = self.model_executor
        await executor.stop_remote_worker_execution_loop_async()

    async def _tokenize_prompt_async(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        """Async version of :meth:`_tokenize_prompt`."""
        tokenizer_group = self.get_tokenizer_group(
            "prompts must be None if skip_tokenizer_init is True")
        return await tokenizer_group.encode_async(
            request_id=request_id,
            prompt=prompt,
            lora_request=lora_request,
        )

    async def _extract_prompt_components_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> PromptComponents:
        """Async version of :meth:`_extract_prompt_components`.

        Returns ``(prompt, prompt_token_ids, multi_modal_data)``. A plain
        string is tokenized. A dict uses its ``prompt_token_ids`` if present;
        otherwise its ``prompt`` is tokenized.
        """
        if isinstance(inputs, str):
            token_ids = await self._tokenize_prompt_async(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
            )
            return inputs, token_ids, None

        if not isinstance(inputs, dict):
            assert_never(inputs)

        if "prompt_token_ids" in inputs:
            prompt = None
            token_ids = inputs["prompt_token_ids"]
        else:
            # NOTE: This extra assignment is required to pass mypy
            prompt = parsed_prompt = inputs["prompt"]
            token_ids = await self._tokenize_prompt_async(
                parsed_prompt,
                request_id=request_id,
                lora_request=lora_request,
            )

        return prompt, token_ids, inputs.get("multi_modal_data")

    async def _process_encoder_decoder_prompt_async(
        self,
        inputs: PromptInputs,
        request_id: str,
    ) -> EncoderDecoderLLMInputs:
        """Async version of :meth:`_process_encoder_decoder_prompt`.

        For an explicit encoder/decoder prompt, both halves are processed
        concurrently when a decoder prompt is given. Otherwise the whole
        input is the encoder prompt and the decoder components are empty.
        """
        encoder_comps: PromptComponents
        decoder_comps: DecoderPromptComponents

        if not is_explicit_encoder_decoder_prompt(inputs):
            encoder_comps = await self._extract_prompt_components_async(
                inputs,
                request_id=request_id,
            )
            decoder_comps = None, None, None
        else:
            encoder_coro = self._extract_prompt_components_async(
                inputs["encoder_prompt"],
                request_id=request_id,
            )
            decoder_input = inputs["decoder_prompt"]
            if decoder_input is not None:
                decoder_coro = self._extract_prompt_components_async(
                    decoder_input,
                    request_id=request_id,
                )
                encoder_comps, decoder_comps = await asyncio.gather(
                    encoder_coro, decoder_coro)
            else:
                encoder_comps = await encoder_coro
                decoder_comps = None, None, None

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)

    async def _process_decoder_only_prompt_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        """Async version of :meth:`_process_decoder_only_prompt`."""
        components = await self._extract_prompt_components_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        )
        return self._build_decoder_only_llm_inputs(
            components,
            prompt_adapter_request=prompt_adapter_request,
        )

    async def process_model_inputs_async(
        self,
        inputs: PromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
        """Async version of :meth:`process_model_inputs`.

        Raises:
            ValueError: if an explicit encoder/decoder prompt is passed to a
                decoder-only model.
        """
        if not self.is_encoder_decoder_model():
            if is_explicit_encoder_decoder_prompt(inputs):
                raise ValueError("Cannot pass encoder-decoder prompt "
                                 "to decoder-only models")
            # Decoder-only operation
            model_inputs = await self._process_decoder_only_prompt_async(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )
        else:
            # Encoder-decoder models need the prompt split into encoder and
            # decoder parts.
            model_inputs = await self._process_encoder_decoder_prompt_async(
                inputs,
                request_id=request_id,
            )

        return self.input_processor(model_inputs)

    async def add_request_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        """Async version of :meth:`add_request`.

        Raises:
            ValueError: if a LoRA request is given while LoRA is disabled.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        arrival_time = time.time() if arrival_time is None else arrival_time

        processed_inputs = await self.process_model_inputs_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
        )

    async def check_health_async(self) -> None:
        """Run the tokenizer's health check (if a tokenizer is set), then the
        model executor's."""
        tokenizer = self.tokenizer
        if tokenizer:
            tokenizer.check_health()
        self.model_executor.check_health()
475

476

477
class AsyncLLMEngine:
478
    """An asynchronous wrapper for :class:`LLMEngine`.
479

480
481
482
483
484
    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.
485
486
487
488
489

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
Zhuohan Li's avatar
Zhuohan Li committed
490
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
491
492
            async frontend will be executed in a separate process as the
            model workers.
493
        log_requests: Whether to log the requests.
494
495
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
496
497
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
498
    """
499

Antoni Baum's avatar
Antoni Baum committed
500
501
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

502
503
504
505
506
    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Create the async wrapper and the underlying engine.

        ``*args``/``**kwargs`` are forwarded to the engine class through
        :meth:`_init_engine`.

        Raises:
            ValueError: if ``engine_use_ray`` is set without
                ``VLLM_ALLOW_ENGINE_USE_RAY``. The check runs *before* the
                engine is built, so a rejected configuration no longer
                constructs the (possibly Ray-actor) engine first.
        """
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests

        if self.engine_use_ray:
            print_warning_once(
                "DEPRECATED. `--engine-use-ray` is deprecated and will "
                "be removed in a future update. "
                "See https://github.com/vllm-project/vllm/issues/7045.")

            if envs.VLLM_ALLOW_ENGINE_USE_RAY:
                print_warning_once(
                    "VLLM_ALLOW_ENGINE_USE_RAY is set, force engine use Ray")
            else:
                raise ValueError("`--engine-use-ray` is deprecated. "
                                 "Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to "
                                 "force use it")

        # Reads self.engine_use_ray / self.worker_use_ray set above.
        self.engine = self._init_engine(*args, **kwargs)

        self.background_loop: Optional[asyncio.Future] = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Lazy initialized fields
        self._request_tracker: RequestTracker

539
    @classmethod
    def _get_executor_cls(
            cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
        """Choose the async executor class for ``engine_config``.

        Resolution order:
          1. an executor class passed directly as
             ``distributed_executor_backend`` (must subclass
             ``ExecutorAsyncBase``);
          2. the device type: neuron, tpu, cpu, openvino, xpu;
          3. otherwise the backend string: "ray", "mp", or the
             single-process GPU executor.
        Ray-based choices initialize the Ray cluster before importing.
        """
        parallel_config = engine_config.parallel_config
        backend = parallel_config.distributed_executor_backend
        device_type = engine_config.device_config.device_type

        if isinstance(backend, type):
            if not issubclass(backend, ExecutorAsyncBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorAsyncBase. Got {backend}.")
            if backend.uses_ray:  # type: ignore
                initialize_ray_cluster(parallel_config)
            return backend

        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            return NeuronExecutorAsync

        if device_type == "tpu":
            if backend == "ray":
                initialize_ray_cluster(parallel_config)
                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                return RayTPUExecutorAsync
            assert backend is None
            from vllm.executor.tpu_executor import TPUExecutorAsync
            return TPUExecutorAsync

        if device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutorAsync
            return CPUExecutorAsync

        if device_type == "openvino":
            assert backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            return OpenVINOExecutorAsync

        if device_type == "xpu":
            if backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                return XPUExecutorAsync
            if backend == "ray":
                initialize_ray_cluster(parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                return RayXPUExecutorAsync
            raise RuntimeError(
                "Not supported distributed execution model on XPU device.")

        if backend == "ray":
            initialize_ray_cluster(parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            return RayGPUExecutorAsync

        if backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            return MultiprocessingGPUExecutorAsync

        from vllm.executor.gpu_executor import GPUExecutorAsync
        return GPUExecutorAsync

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        Builds the engine config and checks Ray availability when
        ``engine_use_ray`` is requested. It then picks the executor class
        and constructs the engine with the config and logging flags.
        """
        engine_config = engine_args.create_engine_config()

        if engine_args.engine_use_ray:
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()

        executor_class = cls._get_executor_cls(engine_config)
        worker_use_ray = executor_class.uses_ray

        return cls(
            worker_use_ray,
            engine_args.engine_use_ray,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

629
630
    @property
    def is_running(self) -> bool:
        """True while the background loop exists and its task is not done."""
        task = self._background_loop_unshielded
        if self.background_loop is None or task is None:
            return False
        return not task.done()

    @property
    def is_stopped(self) -> bool:
        """True if the engine errored, or its background loop task is done."""
        if self.errored:
            return True
        task = self._background_loop_unshielded
        return (self.background_loop is not None and task is not None
                and task.done())

    @property
    def errored(self) -> bool:
        """Whether a failure has been recorded via :meth:`set_errored`."""
        return not (self._errored_with is None)

    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` as the reason the engine is considered errored."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine errored, then fail every tracked request with
        ``exc``."""
        self.set_errored(exc)
        tracker = self._request_tracker
        tracker.propagate_exception(exc)
651

652
653
654
655
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> "PreTrainedTokenizer":
        """Return the tokenizer for ``lora_request``.

        Queries the Ray actor remotely when ``engine_use_ray`` is set.
        Otherwise it asks the local engine's tokenizer group.
        """
        if not self.engine_use_ray:
            tokenizer_group = self.engine.get_tokenizer_group()
            return await tokenizer_group.get_lora_tokenizer_async(lora_request)

        return await self.engine.get_tokenizer.remote(  # type: ignore
            lora_request)
662

663
    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: if a previous loop already errored.
            RuntimeError: if a loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        event_loop = asyncio.get_event_loop()
        loop_task = event_loop.create_task(self.run_engine_loop())
        loop_task.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        # Keep the raw task referenced (see __init__) and expose a shielded
        # view of it as `background_loop`.
        self._background_loop_unshielded = loop_task
        self.background_loop = asyncio.shield(loop_task)
Antoni Baum's avatar
Antoni Baum committed
678

679
680
681
682
683
684
685
686
687
688
689
690
691
692
    def shutdown_background_loop(self) -> None:
        """
        Shut down the background loop.

        Call this during cleanup so references to `self` are dropped and the
        resources held by the async LLM engine (e.g. the executors and their
        resources) can be garbage-collected. Cancels the loop task if one
        exists and clears both loop references.
        """
        loop_task = self._background_loop_unshielded
        if loop_task is not None:
            loop_task.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

Antoni Baum's avatar
Antoni Baum committed
693
694
    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
        """Build the underlying engine, in-process or as a Ray actor.

        - ``engine_use_ray`` false: instantiate ``self._engine_class``
          directly with the given arguments.
        - ``worker_use_ray`` true: create a Ray actor declared with
          ``num_cpus=0``.
        - otherwise: create a Ray actor declared with ``num_gpus``. When both
          tensor- and pipeline-parallel sizes are 1 that value is the cache
          config's ``gpu_memory_utilization``; otherwise it is 1.
        """
        if not self.engine_use_ray:
            return self._engine_class(*args, **kwargs)

        if self.worker_use_ray:
            actor_cls = ray.remote(num_cpus=0)(self._engine_class)
            return actor_cls.remote(*args, **kwargs)

        # FIXME(woosuk): This is a bit hacky. Be careful when changing the
        # order of the arguments.
        cache_config = kwargs["cache_config"]
        parallel_config = kwargs["parallel_config"]
        no_model_parallelism = (parallel_config.tensor_parallel_size == 1
                                and parallel_config.pipeline_parallel_size == 1)
        num_gpus = (cache_config.gpu_memory_utilization
                    if no_model_parallelism else 1)
        actor_cls = ray.remote(num_gpus=num_gpus)(self._engine_class)
        return actor_cls.remote(*args, **kwargs)

713
    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        The method works in three phases:

        1. Every newly queued request from the request tracker is forwarded
           to the engine. A ``ValueError`` raised while adding one is handed
           to the tracker's ``process_exception`` instead of propagating.
        2. Pending aborts are forwarded.
        3. One engine step is run (for ``virtual_engine`` on the local path),
           and each produced output is passed to the tracker.

        Returns:
            True if any output from this step is not finished; False
            otherwise, including when the step produced no outputs.
        """
        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())

        for request_kwargs in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            # TODO: Maybe add add_request_batch to reduce Ray overhead
            try:
                if not self.engine_use_ray:
                    await self.engine.add_request_async(**request_kwargs)
                else:
                    await self.engine.add_request.remote(  # type: ignore
                        **request_kwargs)
            except ValueError as err:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    request_kwargs["request_id"],
                    err,
                    verbose=self.log_requests,
                )

        if aborted_requests:
            await self._engine_abort(aborted_requests)

        if not self.engine_use_ray:
            step_outputs = await self.engine.step_async(virtual_engine)
        else:
            step_outputs = await self.engine.step.remote()  # type: ignore

        # Put the outputs into the corresponding streams.
        has_unfinished = False
        for output in step_outputs:
            self._request_tracker.process_request_output(
                output, verbose=self.log_requests)
            if not output.finished:
                has_unfinished = True
        return has_unfinished
754

Antoni Baum's avatar
Antoni Baum committed
755
756
    async def _engine_abort(self, request_ids: Iterable[str]):
        if self.engine_use_ray:
757
            await self.engine.abort_request.remote(request_ids)  # type: ignore
Antoni Baum's avatar
Antoni Baum committed
758
759
760
761
        else:
            self.engine.abort_request(request_ids)

    async def run_engine_loop(self):
        """Drive the engine until an error escapes.

        While no virtual engine has work, remote workers are told to stop
        their execution loop and the coroutine waits on the request tracker
        for new requests. After that it keeps one ``engine_step`` task per
        virtual engine. Whenever a task completes, it is rescheduled if the
        step reported unfinished outputs or the engine still holds unfinished
        requests for that virtual engine.

        Each wait for a completed step is bounded by
        ``ENGINE_ITERATION_TIMEOUT_S``. On timeout the engine is marked
        errored via ``set_errored`` and the ``asyncio.TimeoutError`` is
        re-raised. Exceptions from an ``engine_step`` task are re-raised here
        through ``task.result()``.
        """
        # A Ray-hosted engine is driven as a single virtual engine; locally
        # the number of virtual engines is pipeline_parallel_size.
        if self.engine_use_ray:
            pipeline_parallel_size = 1  # type: ignore
        else:
            pipeline_parallel_size = \
                self.engine.parallel_config.pipeline_parallel_size
        # has_requests_in_progress[ve]: whether virtual engine `ve` currently
        # has an engine_step task that the loop should keep driving.
        has_requests_in_progress = [False] * pipeline_parallel_size
        while True:
            if not any(has_requests_in_progress):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                if self.engine_use_ray:
                    await (self.engine.stop_remote_worker_execution_loop.
                           remote()  # type: ignore
                           )
                else:
                    await self.engine.stop_remote_worker_execution_loop_async()
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")
                # One step task per virtual engine; the list index is the
                # virtual engine id (looked up again via .index() below).
                requests_in_progress = [
                    asyncio.create_task(self.engine_step(ve))
                    for ve in range(pipeline_parallel_size)
                ]
                has_requests_in_progress = [True] * pipeline_parallel_size

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    # NOTE(review): a virtual engine marked idle keeps its
                    # completed task in requests_in_progress, so while another
                    # virtual engine is still busy this wait returns at once
                    # and the loop re-polls the idle one on every pass
                    # (presumably also how work newly assigned to it is
                    # picked up). Confirm this busy polling is intended.
                    done, _ = await asyncio.wait(
                        requests_in_progress,
                        return_when=asyncio.FIRST_COMPLETED)
                    # Yield to the event loop once per virtual engine,
                    # presumably so other step tasks can advance before the
                    # results are collected -- TODO confirm.
                    for _ in range(pipeline_parallel_size):
                        await asyncio.sleep(0)
                for task in done:
                    # Re-raises any exception raised inside engine_step.
                    result = task.result()
                    virtual_engine = requests_in_progress.index(task)
                    if self.engine_use_ray:
                        has_unfinished_requests = (
                            await (self.engine.
                                   has_unfinished_requests_for_virtual_engine.
                                   remote(  # type: ignore
                                       virtual_engine)))
                    else:
                        has_unfinished_requests = (
                            self.engine.
                            has_unfinished_requests_for_virtual_engine(
                                virtual_engine))
                    # Keep driving this virtual engine while its last step
                    # reported unfinished outputs or it still has requests.
                    if result or has_unfinished_requests:
                        requests_in_progress[virtual_engine] = (
                            asyncio.create_task(
                                self.engine_step(virtual_engine)))
                        has_requests_in_progress[virtual_engine] = True
                    else:
                        has_requests_in_progress[virtual_engine] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise
            # Yield control to the event loop between iterations.
            await asyncio.sleep(0)

828
829
    # This method does not need to be async, but kept that way
    # for backwards compatibility.
Antoni Baum's avatar
Antoni Baum committed
830
831
832
    async def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Register a request with the request tracker and return its stream.

        If the background loop is not running, it is started when
        ``start_engine_loop`` is set. The tracker hands the request to the
        engine on a later background-loop iteration.

        Args:
            request_id: The unique id of the request.
            inputs: The inputs to the LLM.
            params: Sampling parameters (generation) or pooling parameters
                (embedding) for the request.
            arrival_time: Arrival timestamp of the request. ``time.time()``
                is used only when this is ``None``.
            lora_request: LoRA request to use, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            The async generator of the request's stream, which yields its
            outputs.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                ``start_engine_loop`` is disabled.
        """
        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        # Only substitute "now" when no timestamp was supplied. An explicit
        # 0.0 is falsy but still a real value, so `or` must not be used here
        # (this matches LLMEngine.add_request's `is None` handling).
        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            inputs=inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request)

        return stream.generator()
861

862
    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is an async generator. It
        adds the request into the waiting queue of the LLMEngine and streams
        the outputs from the LLMEngine to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use
                for generation, if any.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                ``start_engine_loop`` is disabled.

        Details:
            - If the engine is not running and ``start_engine_loop`` is set,
              start the background loop, which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` stands for the client's HTTP request here.
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in await self.add_request(
                request_id,
                inputs,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
        ):
            yield LLMEngine.validate_output(output, RequestOutput)
944
945
946

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model.

        Generate outputs for a request. This method is an async generator. It
        adds the request into the waiting queue of the LLMEngine and streams
        the outputs from the LLMEngine to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.

        Yields:
            The output `EmbeddingRequestOutput` objects from the LLMEngine
            for the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                ``start_engine_loop`` is disabled.

        Details:
            - If the engine is not running and ``start_engine_loop`` is set,
              start the background loop, which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` stands for the client's HTTP request here.
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in await self.add_request(
                request_id,
                inputs,
                pooling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
        ):
            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
1021

Antoni Baum's avatar
Antoni Baum committed
1022
1023
    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Aborting a request that has already finished, or whose id is not
        found, is a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if self.is_running:
            return self._abort(request_id)

        raise AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")
1039

Antoni Baum's avatar
Antoni Baum committed
1040
    def _abort(self, request_id: str) -> None:
1041
1042
1043
1044
1045
1046
1047
1048
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
1049
        self._request_tracker.abort_request(request_id,
1050
                                            exception=asyncio.CancelledError,
1051
                                            verbose=self.log_requests)
1052

1053
1054
1055
    async def get_model_config(self) -> ModelConfig:
        """Return the vLLM engine's model configuration.

        For a Ray-hosted engine the config is fetched from the remote actor.
        """
        if not self.engine_use_ray:
            return self.engine.get_model_config()
        return await self.engine.get_model_config.remote()  # type: ignore

1060
1061
1062
1063
1064
1065
1066
1067
    async def get_parallel_config(self) -> ParallelConfig:
        """Return the vLLM engine's parallel configuration.

        For a Ray-hosted engine the config is fetched from the remote actor.
        """
        if not self.engine_use_ray:
            return self.engine.get_parallel_config()
        return await self.engine.get_parallel_config.remote()  # type: ignore

1068
1069
1070
1071
1072
1073
1074
1075
    async def get_decoding_config(self) -> DecodingConfig:
        """Return the vLLM engine's decoding configuration.

        For a Ray-hosted engine the config is fetched from the remote actor.
        """
        if not self.engine_use_ray:
            return self.engine.get_decoding_config()
        return await self.engine.get_decoding_config.remote()  # type: ignore

1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
    async def get_scheduler_config(self) -> SchedulerConfig:
        """Return the vLLM engine's scheduling configuration.

        For a Ray-hosted engine the config is fetched from the remote actor.
        """
        if not self.engine_use_ray:
            return self.engine.get_scheduler_config()
        return await self.engine.get_scheduler_config.remote()  # type: ignore

    async def get_lora_config(self) -> LoRAConfig:
        """Return the vLLM engine's LoRA configuration.

        For a Ray-hosted engine the config is fetched from the remote actor.
        """
        if not self.engine_use_ray:
            return self.engine.get_lora_config()
        return await self.engine.get_lora_config.remote()  # type: ignore

1092
1093
1094
1095
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the underlying engine to log its stats.

        Both arguments are forwarded unchanged to the engine's
        ``do_log_stats``, whether the engine is a Ray actor or runs locally.

        Args:
            scheduler_outputs: Optional scheduler outputs for the engine.
            model_output: Optional list of sampler outputs for the engine.
        """
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote(  # type: ignore
                scheduler_outputs, model_output)
        else:
            # The Ray actor wraps the same engine class and receives both
            # arguments, so forward them locally as well instead of silently
            # dropping them.
            self.engine.do_log_stats(scheduler_outputs, model_output)
1101

1102
    async def check_health(self) -> None:
        """Raise if the engine is unhealthy; log how long the check took.

        Raises:
            AsyncEngineDeadError: If the background loop is stopped.
            RuntimeError: If the Ray engine actor has died
                (``RayActorError``).
        """
        started_at = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        if not self.engine_use_ray:
            await self.engine.check_health_async()
        else:
            try:
                await self.engine.check_health.remote()  # type: ignore
            except ray.exceptions.RayActorError as e:
                raise RuntimeError("Engine is dead.") from e

        elapsed = time.perf_counter() - started_at
        logger.debug("Health check took %fs", elapsed)
1117
1118
1119
1120
1121
1122
1123

    async def is_tracing_enabled(self) -> bool:
        """Report whether tracing is enabled on the underlying engine.

        For a Ray-hosted engine the answer comes from the remote actor.
        """
        if not self.engine_use_ray:
            return self.engine.is_tracing_enabled()
        return await self.engine.is_tracing_enabled.remote()  # type: ignore
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register ``logger`` on the engine under ``logger_name``.

        For a Ray-hosted engine this blocks on ``ray.get`` until the remote
        call completes.
        """
        if not self.engine_use_ray:
            self.engine.add_logger(logger_name=logger_name, logger=logger)
            return
        pending = self.engine.add_logger.remote(  # type: ignore
            logger_name=logger_name, logger=logger)
        ray.get(pending)

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the engine's stat logger named ``logger_name``.

        For a Ray-hosted engine this blocks on ``ray.get`` until the remote
        call completes.
        """
        if not self.engine_use_ray:
            self.engine.remove_logger(logger_name=logger_name)
            return
        pending = self.engine.remove_logger.remote(  # type: ignore
            logger_name=logger_name)
        ray.get(pending)