async_llm_engine.py 51.4 KB
Newer Older
1
2
import asyncio
import time
3
from dataclasses import dataclass
Antoni Baum's avatar
Antoni Baum committed
4
from functools import partial
5
from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
6
                    Optional, Set, Tuple, Type, Union)
7

8
import torch
9
from typing_extensions import assert_never
10

11
import vllm.envs as envs
12
13
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
14
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
15
from vllm.engine.arg_utils import AsyncEngineArgs
16
from vllm.engine.async_timeout import asyncio_timeout
17
18
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
                                    PromptComponents)
19
from vllm.engine.metrics_types import StatLoggerBase
20
from vllm.executor.executor_base import ExecutorAsyncBase
21
from vllm.executor.ray_utils import initialize_ray_cluster, ray
22
23
24
from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
                         SingletonPromptInputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
Woosuk Kwon's avatar
Woosuk Kwon committed
25
from vllm.logger import init_logger
26
from vllm.lora.request import LoRARequest
27
28
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
29
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
30
from vllm.sampling_params import SamplingParams
31
32
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                           SequenceGroupMetadata)
33
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
34
from vllm.usage.usage_lib import UsageContext
35
from vllm.utils import print_warning_once
36
37

# Module-level logger shared by every class and helper in this file.
logger = init_logger(__name__)

# Engine-iteration timeout read from `vllm.envs`
# (VLLM_ENGINE_ITERATION_TIMEOUT_S); the `_S` suffix suggests seconds.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
39

Antoni Baum's avatar
Antoni Baum committed
40

41
42
43
44
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine background task has terminated with an error
    and the engine can therefore no longer make progress."""


45
46
47
48
49
50
51
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback intended only for the `engine.run_engine_loop()` task.

    That task runs a `while True` loop, so the only legitimate ways for it to
    stop are cancellation or an exception.

    Args:
        task: The completed engine-loop task.
        error_callback: Invoked with the failure before it is re-raised.

    Raises:
        AsyncEngineDeadError: if the task failed, or returned normally
            (which is itself treated as a failure).
    """
    try:
        outcome = task.result()
        # Getting here means the loop returned instead of raising. This
        # AssertionError is intentionally caught by the generic handler below,
        # so it is logged and reported exactly like any other failure.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {outcome}")
    except asyncio.exceptions.CancelledError:
        # Cancellation is assumed to mean a graceful shutdown at program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as failure:
        logger.error("Engine background task failed", exc_info=failure)
        error_callback(failure)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from failure
71
72


73
74
75
# Sentinel queued by `AsyncStream.finish()` when a stream ends without an
# error. `AsyncStream.generator()` returns (instead of raising) when it
# dequeues this exact object, so its identity is what matters.
STOP_ITERATION = Exception()  # Sentinel


Antoni Baum's avatar
Antoni Baum committed
76
class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator.

    Producers call :meth:`put` for each output and :meth:`finish` once to
    terminate the stream, optionally with an exception (an instance or an
    exception class) that the consumer will see raised. The consumer drains
    every item queued before the terminator, so outputs put immediately
    before :meth:`finish` are always delivered.
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        # Called with ``request_id`` if the consumer closes the generator early.
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Queue an output for the consumer; ignored once finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Terminate the stream. Calls after the first are no-ops.

        Args:
            exception: If given (instance or class), it is raised in the
                consumer after all previously queued items are yielded.
                Otherwise the generator ends normally.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if exception is not None else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Yield queued outputs until the terminator is reached.

        Raises:
            The exception passed to :meth:`finish`, if any.
            asyncio.CancelledError: if the consumer closes the generator
                early (after ``cancel(request_id)`` has been called).
        """
        try:
            # Loop until the terminator is dequeued rather than testing
            # ``self._finished``: finish() may run while outputs are still
            # buffered, and those outputs must be delivered first.
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: object) -> bool:
        """Return True for exception instances *and* exception classes.

        ``finish`` may be given a class such as ``asyncio.CancelledError``.
        That class derives from ``BaseException`` (not ``Exception``) on
        Python 3.8+. A plain ``isinstance(value, Exception)`` check would
        therefore yield it to the consumer as if it were an output.
        """
        return isinstance(value, BaseException) or (
            isinstance(value, type) and issubclass(value, BaseException))
Antoni Baum's avatar
Antoni Baum committed
119
120


121
122
123
124
125
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    Holds one :class:`AsyncStream` per request handed to the engine. Newly
    added and aborted request ids are buffered until the background loop
    collects them with :meth:`get_new_and_aborted_requests`.
    """

    def __init__(self) -> None:
        # request_id -> stream, for requests already handed to the engine.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Ids aborted since the last get_new_and_aborted_requests() call.
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, engine add_request kwargs) pairs not yet handed over.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set by add_request(); awaited and cleared by wait_for_new_requests().
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Abort request stream(s) with ``exc``.

        Targets only ``request_id`` when given, otherwise every tracked
        request.
        """
        if request_id is None:
            # Snapshot the ids first: abort_request() pops entries from
            # self._request_streams, so it cannot be iterated directly.
            targets = list(self._request_streams)
        else:
            targets = [request_id]
        for rid in targets:
            self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Route one engine output to its stream, closing it when finished."""
        request_id = request_output.request_id
        finished = request_output.finished

        # Finished requests stop being tracked. Either lookup may miss if the
        # request was aborted while this output was being generated.
        streams = self._request_streams
        stream = (streams.pop(request_id, None)
                  if finished else streams.get(request_id))

        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if verbose and finished:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Abort ``request_id`` so its consumer sees ``exception``."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Queue a request for the engine's next background loop iteration.

        Raises:
            KeyError: if ``request_id`` is already being tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id,
                             partial(self.abort_request, verbose=verbose))
        engine_kwargs = {"request_id": request_id, **engine_add_request_kwargs}
        self._new_requests.put_nowait((stream, engine_kwargs))

        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Abort a request; the engine is told on the next loop iteration.

        If the request is already tracked, its stream is finished right away
        (with ``exception`` if given).
        """
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)

        if (stream := self._request_streams.pop(request_id, None)) is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Drain buffered new and aborted requests for the engine.

        Returns:
            ``(new_requests, finished_requests)``: engine add_request kwargs
            for each new request, and the set of aborted request ids.
        """
        finished_requests: Set[str] = set()
        aborted = self._aborted_requests
        while not aborted.empty():
            finished_requests.add(aborted.get_nowait())

        new_requests: List[Dict] = []
        pending = self._new_requests
        while not pending.empty():
            stream, new_request = pending.get_nowait()
            rid = stream.request_id
            if rid not in finished_requests:
                self._request_streams[rid] = stream
                new_requests.append(new_request)
                continue
            # Aborted before it was ever handed over: cancel its stream and
            # drop the id so it is not reported among finished requests.
            stream.finish(asyncio.CancelledError)
            finished_requests.discard(rid)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Wait until a new request is pending, then reset the event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Return True if add_request() has queued anything not yet drained."""
        return not self._new_requests.empty()
253

Antoni Baum's avatar
Antoni Baum committed
254

255
256
257
258
259
260
261
262
@dataclass
class SchedulerOutputState:
    """Scheduler outputs cached for one virtual engine during multi-step
    decoding. All fields default to None, meaning nothing is cached."""
    # Latest model output. With pipeline parallelism, its
    # sampled_token_ids_cpu is forwarded to the next iteration's request.
    last_output: Optional[SamplerOutput] = None
    # Metadata for the batch being stepped; reused until its steps run out.
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    # Scheduler outputs paired with seq_group_metadata_list.
    scheduler_outputs: Optional[SchedulerOutputs] = None


Antoni Baum's avatar
Antoni Baum committed
263
264
265
class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        pipeline_parallel_size = \
            self.parallel_config.pipeline_parallel_size
        # One multi-step cache slot per virtual engine; the list has
        # pipeline_parallel_size entries and is indexed by virtual_engine.
        self.cached_scheduler_outputs = [
            SchedulerOutputState() for _ in range(pipeline_parallel_size)
        ]

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):
            seq_group_metadata_list, scheduler_outputs = self.scheduler[
                virtual_engine].schedule()

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs)

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)
            # Execute the model.
            output = await self.model_executor.execute_model_async(
                execute_model_req)
            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, output)
        else:
            output = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()
            request_outputs = self._process_model_outputs(
                output, scheduler_outputs.scheduled_seq_groups,
                scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
        else:
            request_outputs = []

        # Log stats.
        self.do_log_stats(scheduler_outputs, output)

        # Tracing
        self.do_tracing(scheduler_outputs)

        return request_outputs

    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
        """Return True if the current multi-step batch still has steps left.

        Always False when multi-step scheduling is disabled or the list is
        None/empty.

        Raises:
            AssertionError: if the sequence groups disagree on the number of
                remaining steps.
        """
        if (not self.scheduler_config.is_multi_step
                or not seq_group_metadata_list):
            return False

        # TODO(will) this is a sanity check for now to make sure that all the
        # seqs are on the same steps. Eventually we will want to do some sort of
        # dynamic scheduling when doing multi-step decoding.
        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
        # Generator (not a list) so any() stops at the first mismatch.
        if any(seq_group.state.remaining_steps != ref_remaining_steps
               for seq_group in seq_group_metadata_list[1:]):
            raise AssertionError(("All running sequence groups should "
                                  "have the same remaining steps."))

        return ref_remaining_steps > 0

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
            scheduler_outputs: SchedulerOutputs) -> None:
        """Store a freshly scheduled batch for reuse on later steps and reset
        the cached last output for this virtual engine."""
        self.cached_scheduler_outputs[
            virtual_engine].seq_group_metadata_list = seq_group_metadata_list
        self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \
            scheduler_outputs
        self.cached_scheduler_outputs[virtual_engine].last_output = None

    def _get_last_sampled_token_ids(
            self, virtual_engine: int) -> Optional[torch.Tensor]:
        """Return the cached CPU sampled token ids from the previous step.

        Returns None unless multi-step is on, pipeline parallelism is used
        (size > 1), and a cached output with ``sampled_token_ids_cpu`` exists.
        """
        cached_last_output = self.cached_scheduler_outputs[
            virtual_engine].last_output
        if (self.scheduler_config.is_multi_step
                and self.parallel_config.pipeline_parallel_size > 1
                and cached_last_output is not None
                and cached_last_output.sampled_token_ids_cpu is not None):
            return cached_last_output.sampled_token_ids_cpu
        return None

    def _update_cached_scheduler_output(
            self, virtual_engine: int,
            output: List[Optional[SamplerOutput]]) -> None:
        """Cache the last model output for the next step (pipeline-parallel
        only; skipped when the output is empty or its first entry is None)."""
        if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
                and output[0] is not None):
            last_output = output[-1]
            assert last_output is not None
            assert last_output.sampled_token_ids_cpu is not None
            assert last_output.sampled_token_ids is None
            assert last_output.sampled_token_probs is None
            self.cached_scheduler_outputs[
                virtual_engine].last_output = last_output

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def _tokenize_prompt_async(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        """Async version of :meth:`_tokenize_prompt`."""
        tokenizer = self.get_tokenizer_group(
            missing_msg="prompts must be None if skip_tokenizer_init is True")

        return await tokenizer.encode_async(request_id=request_id,
                                            prompt=prompt,
                                            lora_request=lora_request)

    async def _extract_prompt_components_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> PromptComponents:
        """Async version of :meth:`_extract_prompt_components`.

        Returns ``(prompt, prompt_token_ids, multi_modal_data)``. ``prompt``
        is None when token ids were supplied directly.
        """
        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = await self._tokenize_prompt_async(
                prompt,
                request_id=request_id,
                lora_request=lora_request,
            )
            multi_modal_data = None
        elif isinstance(inputs, dict):
            if "prompt_token_ids" in inputs:
                prompt = None
                prompt_token_ids = inputs["prompt_token_ids"]
            else:
                # NOTE: This extra assignment is required to pass mypy
                prompt = parsed_prompt = inputs["prompt"]
                prompt_token_ids = await self._tokenize_prompt_async(
                    parsed_prompt,
                    request_id=request_id,
                    lora_request=lora_request,
                )

            multi_modal_data = inputs.get("multi_modal_data")
        else:
            assert_never(inputs)

        return prompt, prompt_token_ids, multi_modal_data

    async def _process_encoder_decoder_prompt_async(
        self,
        inputs: PromptInputs,
        request_id: str,
    ) -> EncoderDecoderLLMInputs:
        """Async version of :meth:`_process_encoder_decoder_prompt`.

        When both an encoder and a decoder prompt are given, they are
        processed concurrently with ``asyncio.gather``.
        """
        encoder_comps: PromptComponents
        decoder_comps: DecoderPromptComponents

        if is_explicit_encoder_decoder_prompt(inputs):
            encoder_task = self._extract_prompt_components_async(
                inputs["encoder_prompt"],
                request_id=request_id,
            )

            if (decoder_input := inputs["decoder_prompt"]) is None:
                encoder_comps = await encoder_task
                decoder_comps = None, None, None
            else:
                decoder_task = self._extract_prompt_components_async(
                    decoder_input,
                    request_id=request_id,
                )

                encoder_comps, decoder_comps = await asyncio.gather(
                    encoder_task, decoder_task)
        else:
            encoder_comps = await self._extract_prompt_components_async(
                inputs,
                request_id=request_id,
            )

            decoder_comps = None, None, None

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)

    async def _process_decoder_only_prompt_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        """Async version of :meth:`_process_decoder_only_prompt`."""
        prompt_comps = await self._extract_prompt_components_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        )

        return self._build_decoder_only_llm_inputs(
            prompt_comps,
            prompt_adapter_request=prompt_adapter_request,
        )

    async def process_model_inputs_async(
        self,
        inputs: PromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
        """Async version of :meth:`process_model_inputs`.

        Raises:
            ValueError: if an explicit encoder/decoder prompt is passed to a
                decoder-only model.
        """
        if self.is_encoder_decoder_model():
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder
            model_inputs = await self._process_encoder_decoder_prompt_async(
                inputs,
                request_id=request_id,
            )
        else:
            if is_explicit_encoder_decoder_prompt(inputs):
                raise ValueError("Cannot pass encoder-decoder prompt "
                                 "to decoder-only models")

            # Decoder-only operation
            model_inputs = await self._process_decoder_only_prompt_async(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )

        return self.input_processor(model_inputs)

    async def add_request_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        """Async version of :meth:`add_request`.

        Raises:
            ValueError: if ``lora_request`` is given but LoRA is not enabled.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if arrival_time is None:
            arrival_time = time.time()

        processed_inputs = await self.process_model_inputs_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
        )

    async def check_health_async(self) -> None:
        """Check the tokenizer (if one is configured) and the model executor."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
594

595

596
class AsyncLLMEngine:
597
    """An asynchronous wrapper for :class:`LLMEngine`.
598

599
600
601
602
603
    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.
604
605
606
607
608

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
Zhuohan Li's avatar
Zhuohan Li committed
609
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
610
611
            async frontend will be executed in a separate process as the
            model workers.
612
        log_requests: Whether to log the requests.
613
614
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
615
616
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
617
    """
618

Antoni Baum's avatar
Antoni Baum committed
619
620
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

621
622
623
624
625
    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Create the async wrapper and the underlying engine.

        Args:
            worker_use_ray: Whether the model workers run on Ray.
            engine_use_ray: Whether to run the engine itself as a Ray actor
                (deprecated; requires ``VLLM_ALLOW_ENGINE_USE_RAY``).
            *args: Positional arguments forwarded to the engine class.
            log_requests: Whether to log the requests.
            start_engine_loop: Whether ``add_request`` may lazily start the
                background loop.
            **kwargs: Keyword arguments forwarded to the engine class.

        Raises:
            ValueError: if ``engine_use_ray`` is set but
                ``VLLM_ALLOW_ENGINE_USE_RAY`` is not.
        """
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests

        # Validate the deprecated `engine_use_ray` mode *before* building the
        # engine: `_init_engine` may launch a Ray actor and construct the
        # model, which is expensive and would be left behind if we raised
        # only afterwards.
        if self.engine_use_ray:
            print_warning_once(
                "DEPRECATED. `--engine-use-ray` is deprecated and will "
                "be removed in a future update. "
                "See https://github.com/vllm-project/vllm/issues/7045.")

            if envs.VLLM_ALLOW_ENGINE_USE_RAY:
                print_warning_once(
                    "VLLM_ALLOW_ENGINE_USE_RAY is set, force engine use Ray")
            else:
                raise ValueError("`--engine-use-ray` is deprecated. "
                                 "Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to "
                                 "force use it")

        # `_init_engine` reads `engine_use_ray` / `worker_use_ray`, so those
        # must be assigned above.
        self.engine = self._init_engine(*args, **kwargs)

        self.background_loop: Optional[asyncio.Future] = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Lazy initialized fields
        self._request_tracker: RequestTracker

658
    @classmethod
    def _get_executor_cls(
            cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
        """Select the async executor class for ``engine_config``.

        Resolution order:
          1. An executor *class* given directly as
             ``parallel_config.distributed_executor_backend``. It must
             subclass :class:`ExecutorAsyncBase`. A Ray cluster is
             initialized first when the class reports ``uses_ray``.
          2. Device-specific executors for ``neuron``, ``tpu``, ``cpu``,
             ``openvino`` and ``xpu`` device types.
          3. For any other device type, the ``"ray"`` / ``"mp"`` backend
             strings select the Ray / multiprocessing GPU executors.
             Otherwise the single-process GPU executor is used.

        Executor modules are imported lazily, only for the branch taken.

        Raises:
            TypeError: if a backend class is not an ExecutorAsyncBase.
            RuntimeError: for an unsupported distributed backend on XPU.
        """
        parallel_config = engine_config.parallel_config
        backend = parallel_config.distributed_executor_backend
        device_type = engine_config.device_config.device_type

        if isinstance(backend, type):
            if not issubclass(backend, ExecutorAsyncBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorAsyncBase. Got {backend}.")
            if backend.uses_ray:  # type: ignore
                initialize_ray_cluster(parallel_config)
            return backend

        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            return NeuronExecutorAsync

        if device_type == "tpu":
            if backend == "ray":
                initialize_ray_cluster(parallel_config)
                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                return RayTPUExecutorAsync
            assert backend is None
            from vllm.executor.tpu_executor import TPUExecutorAsync
            return TPUExecutorAsync

        if device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutorAsync
            return CPUExecutorAsync

        if device_type == "openvino":
            assert backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            return OpenVINOExecutorAsync

        if device_type == "xpu":
            if backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                return XPUExecutorAsync
            if backend == "ray":
                initialize_ray_cluster(parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                return RayXPUExecutorAsync
            raise RuntimeError(
                "Not supported distributed execution model on XPU device.")

        # Remaining device types fall through to the GPU executors.
        if backend == "ray":
            initialize_ray_cluster(parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            return RayGPUExecutorAsync
        if backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            return MultiprocessingGPUExecutorAsync
        from vllm.executor.gpu_executor import GPUExecutorAsync
        return GPUExecutorAsync

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        Builds the engine config from ``engine_args``, picks the executor
        class via :meth:`_get_executor_cls` and constructs the engine. The
        executor's ``uses_ray`` flag becomes ``worker_use_ray``.
        """
        engine_config = engine_args.create_engine_config()

        if engine_args.engine_use_ray:
            # Presumably raises when Ray cannot be used; verify in ray_utils.
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()

        executor_class = cls._get_executor_cls(engine_config)

        # Create the async LLM engine.
        return cls(
            executor_class.uses_ray,
            engine_args.engine_use_ray,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

748
749
    @property
    def is_running(self) -> bool:
        """True while the background loop has been started and its task has
        not completed."""
        task = self._background_loop_unshielded
        if self.background_loop is None or task is None:
            return False
        return not task.done()

    @property
    def is_stopped(self) -> bool:
        """True if an error was recorded, or if the background loop was
        started and its task has since completed."""
        if self.errored:
            return True
        task = self._background_loop_unshielded
        return (self.background_loop is not None and task is not None
                and task.done())

    @property
    def errored(self) -> bool:
        """Whether an error has been recorded via :meth:`set_errored`."""
        recorded = self._errored_with
        return recorded is not None

    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` as the engine's fatal error (``errored`` becomes
        True and ``start_background_loop`` will refuse to restart)."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Error hook passed to the background-loop task's done callback.

        Records ``exc`` on the engine, then hands it to the request tracker
        through ``propagate_exception``.
        """
        self.set_errored(exc)
        tracker = self._request_tracker
        tracker.propagate_exception(exc)
770

771
772
773
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer for ``lora_request`` (or the base one).

        Locally this goes through the engine's tokenizer group. When the
        engine is a Ray actor, it is a remote ``get_tokenizer`` call.
        """
        if not self.engine_use_ray:
            tokenizer_group = self.engine.get_tokenizer_group()
            return await tokenizer_group.get_lora_tokenizer_async(lora_request)

        return await self.engine.get_tokenizer.remote(  # type: ignore
            lora_request)
781

782
    def start_background_loop(self) -> None:
        """Start the background loop.

        Creates a fresh RequestTracker, schedules :meth:`run_engine_loop` as
        a task on the current event loop, attaches the completion callback,
        and exposes a shielded view of the task as ``background_loop``.

        Raises:
            AsyncEngineDeadError: if an error was already recorded.
            RuntimeError: if the loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        event_loop = asyncio.get_event_loop()
        loop_task = event_loop.create_task(self.run_engine_loop())
        # Keep a strong reference to the unshielded task so it is not
        # garbage collected while running.
        self._background_loop_unshielded = loop_task
        on_done = partial(_log_task_completion,
                          error_callback=self._error_callback)
        loop_task.add_done_callback(on_done)
        self.background_loop = asyncio.shield(loop_task)
Antoni Baum's avatar
Antoni Baum committed
797

798
799
800
801
802
803
804
805
806
807
808
809
810
811
    def shutdown_background_loop(self) -> None:
        """
        Shut down the background loop.

        Cancels the background task (if any) and drops both references to
        it. Call this during cleanup so the async LLM engine no longer holds
        references to `self`, which lets the resources it owns (e.g. the
        executors) be garbage collected.
        """
        task = self._background_loop_unshielded
        if task is not None:
            task.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

Antoni Baum's avatar
Antoni Baum committed
812
813
    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
        """Instantiate the wrapped engine locally or as a Ray actor.

        Without ``engine_use_ray`` the engine class is constructed directly.
        With it, the class is wrapped with ``ray.remote`` and started as an
        actor. If the workers also use Ray, the actor reserves no CPUs.
        Otherwise it reserves GPUs: a ``gpu_memory_utilization`` fraction
        for a single-device (TP=1, PP=1) setup, else one full GPU.
        """
        if not self.engine_use_ray:
            return self._engine_class(*args, **kwargs)

        if self.worker_use_ray:
            actor_cls = ray.remote(num_cpus=0)(self._engine_class)
        else:
            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
            # order of the arguments.
            cache_config = kwargs["cache_config"]
            parallel_config = kwargs["parallel_config"]
            single_device = (parallel_config.tensor_parallel_size == 1
                             and parallel_config.pipeline_parallel_size == 1)
            num_gpus = (cache_config.gpu_memory_utilization
                        if single_device else 1)
            actor_cls = ray.remote(num_gpus=num_gpus)(self._engine_class)
        return actor_cls.remote(*args, **kwargs)

832
    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Drains new/aborted requests from the tracker and submits them to the
        engine. A submission that raises ValueError is reported to the
        tracker for that request. Then runs one engine step and routes every
        output back to its stream.

        Returns True if there are in-progress requests."""
        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())

        for request_kwargs in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            # TODO: Maybe add add_request_batch to reduce Ray overhead
            try:
                if not self.engine_use_ray:
                    await self.engine.add_request_async(**request_kwargs)
                else:
                    await self.engine.add_request.remote(  # type: ignore
                        **request_kwargs)
            except ValueError as err:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    request_kwargs["request_id"],
                    err,
                    verbose=self.log_requests,
                )

        if aborted_requests:
            await self._engine_abort(aborted_requests)

        if not self.engine_use_ray:
            outputs = await self.engine.step_async(virtual_engine)
        else:
            outputs = await self.engine.step.remote()  # type: ignore

        # Put the outputs into the corresponding streams.
        any_unfinished = False
        for output in outputs:
            self._request_tracker.process_request_output(
                output, verbose=self.log_requests)
            any_unfinished = any_unfinished or not output.finished

        return any_unfinished
873

Antoni Baum's avatar
Antoni Baum committed
874
875
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Forward an abort for ``request_ids`` to the wrapped engine.

        The Ray-actor path is a remote call; the local path is synchronous.
        """
        if not self.engine_use_ray:
            self.engine.abort_request(request_ids)
            return
        await self.engine.abort_request.remote(request_ids)  # type: ignore

    async def run_engine_loop(self):
        """Drive the wrapped engine until cancelled or errored.

        One :meth:`engine_step` task is kept per virtual engine. There is one
        per pipeline-parallel stage locally, and a single one when the engine
        is a Ray actor.

        When no virtual engine is active, the loop stops the remote workers'
        execution loop and waits for the request tracker to report new
        requests. It then starts a step for every virtual engine. Otherwise
        it waits for the first finished step. It re-schedules that virtual
        engine if the step reported in-progress requests or the engine still
        has unfinished requests for it. Otherwise it marks the virtual engine
        idle.

        Raises:
            asyncio.TimeoutError: e.g. when waiting for a step exceeds
                ``ENGINE_ITERATION_TIMEOUT_S``. The engine is marked errored
                before the exception is re-raised.
        """
        if self.engine_use_ray:
            num_virtual_engines = 1  # type: ignore
        else:
            num_virtual_engines = (
                self.engine.parallel_config.pipeline_parallel_size)
        # active[ve] is True while virtual engine `ve` should keep stepping.
        active = [False] * num_virtual_engines

        async def has_unfinished(ve: int) -> bool:
            # Ask the (possibly remote) engine whether `ve` still has work.
            if self.engine_use_ray:
                return await (self.engine.
                              has_unfinished_requests_for_virtual_engine.
                              remote(ve))  # type: ignore
            return self.engine.has_unfinished_requests_for_virtual_engine(ve)

        while True:
            if not any(active):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                if self.engine_use_ray:
                    await (self.engine.stop_remote_worker_execution_loop.
                           remote()  # type: ignore
                           )
                else:
                    await self.engine.stop_remote_worker_execution_loop_async()
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")
                step_tasks = [
                    asyncio.create_task(self.engine_step(ve))
                    for ve in range(num_virtual_engines)
                ]
                active = [True] * num_virtual_engines

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        step_tasks, return_when=asyncio.FIRST_COMPLETED)
                    # Yield to the event loop once per virtual engine.
                    for _ in range(num_virtual_engines):
                        await asyncio.sleep(0)
                for finished_task in done:
                    step_reported_work = finished_task.result()
                    ve = step_tasks.index(finished_task)
                    engine_has_work = await has_unfinished(ve)
                    if step_reported_work or engine_has_work:
                        step_tasks[ve] = asyncio.create_task(
                            self.engine_step(ve))
                        active[ve] = True
                    else:
                        active[ve] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise

            await asyncio.sleep(0)

947
948
    # This method does not need to be async, but kept that way
    # for backwards compatibility.
    async def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Register a request with the tracker and return its output stream.

        Starts the background loop first if it is not running and
        ``start_engine_loop`` allows it.

        Args:
            request_id: The unique id of the request.
            inputs: The prompt inputs.
            params: Sampling (generation) or pooling (embedding) parameters.
            arrival_time: Arrival timestamp; the current ``time.time()`` is
                used only when this is ``None``.
            lora_request: LoRA request to use, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            An async generator yielding the request's outputs.

        Raises:
            AsyncEngineDeadError: if the loop is not running and may not be
                started automatically.
        """
        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        # Use an explicit `is None` test: `arrival_time or time.time()` would
        # discard a caller-supplied 0.0 timestamp. This matches
        # `_AsyncLLMEngine.add_request_async`.
        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            inputs=inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request)

        return stream.generator()
980

981
    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use
                for generation, if any.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming HTTP request of the server
            >>>     # (see entrypoints/api_server.py).
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in await self.add_request(
                request_id,
                inputs,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
        ):
            # Narrow each output to RequestOutput via the engine's validator.
            yield LLMEngine.validate_output(output, RequestOutput)
1063
1064
1065

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.

        Yields:
            The output `EmbeddingRequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming HTTP request of the server
            >>>     # (see entrypoints/api_server.py).
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in await self.add_request(
                request_id,
                inputs,
                pooling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
        ):
            # Narrow each output to EmbeddingRequestOutput via the validator.
            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
1140

Antoni Baum's avatar
Antoni Baum committed
1141
1142
    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: if the background loop is not running.
        """
        if self.is_running:
            return self._abort(request_id)

        raise AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")
1158

Antoni Baum's avatar
Antoni Baum committed
1159
    def _abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        The request tracker is asked to abort ``request_id`` with
        ``asyncio.CancelledError`` as the exception type.

        Args:
            request_id: The unique id of the request.
        """
        tracker = self._request_tracker
        tracker.abort_request(request_id,
                              exception=asyncio.CancelledError,
                              verbose=self.log_requests)
1171

1172
1173
1174
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine.

        Uses a remote call when the engine runs as a Ray actor."""
        if not self.engine_use_ray:
            return self.engine.get_model_config()
        return await self.engine.get_model_config.remote()  # type: ignore

1179
1180
1181
1182
1183
1184
1185
1186
    async def get_parallel_config(self) -> ParallelConfig:
        """Get the parallel configuration of the vLLM engine.

        Uses a remote call when the engine runs as a Ray actor."""
        if not self.engine_use_ray:
            return self.engine.get_parallel_config()
        return await self.engine.get_parallel_config.remote(  # type: ignore
        )

1187
1188
1189
1190
1191
1192
1193
1194
    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine.

        Uses a remote call when the engine runs as a Ray actor."""
        if not self.engine_use_ray:
            return self.engine.get_decoding_config()
        return await self.engine.get_decoding_config.remote(  # type: ignore
        )

1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
    async def get_scheduler_config(self) -> SchedulerConfig:
        """Get the scheduling configuration of the vLLM engine.

        Uses a remote call when the engine runs as a Ray actor."""
        if not self.engine_use_ray:
            return self.engine.get_scheduler_config()
        return await self.engine.get_scheduler_config.remote(  # type: ignore
        )

    async def get_lora_config(self) -> LoRAConfig:
        """Get the lora configuration of the vLLM engine.

        Uses a remote call when the engine runs as a Ray actor."""
        if not self.engine_use_ray:
            return self.engine.get_lora_config()
        return await self.engine.get_lora_config.remote(  # type: ignore
        )

1211
1212
1213
1214
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the wrapped engine to log its statistics.

        Args:
            scheduler_outputs: Optional scheduler outputs to include.
            model_output: Optional model outputs to include.

        Both arguments are forwarded to the engine's ``do_log_stats`` on the
        Ray-actor path and on the local path alike. Previously the local
        path dropped them silently.
        """
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote(  # type: ignore
                scheduler_outputs, model_output)
        else:
            self.engine.do_log_stats(scheduler_outputs, model_output)
1220

1221
    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy.

        Raises:
            AsyncEngineDeadError: if the background loop is stopped.
            RuntimeError: if the engine's Ray actor has died.
        """
        started_at = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        if not self.engine_use_ray:
            await self.engine.check_health_async()
        else:
            try:
                await self.engine.check_health.remote()  # type: ignore
            except ray.exceptions.RayActorError as err:
                raise RuntimeError("Engine is dead.") from err

        elapsed = time.perf_counter() - started_at
        logger.debug("Health check took %fs", elapsed)
1236
1237
1238
1239
1240
1241
1242

    async def is_tracing_enabled(self) -> bool:
        """Whether the wrapped engine has tracing enabled (remote call when
        the engine is a Ray actor)."""
        if not self.engine_use_ray:
            return self.engine.is_tracing_enabled()
        return await self.engine.is_tracing_enabled.remote(  # type: ignore
        )
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register stat ``logger`` under ``logger_name`` on the engine.

        In Ray mode this blocks on ``ray.get`` until the remote call returns.
        """
        if not self.engine_use_ray:
            self.engine.add_logger(logger_name=logger_name, logger=logger)
            return
        ray.get(
            self.engine.add_logger.remote(  # type: ignore
                logger_name=logger_name, logger=logger))

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger named ``logger_name`` from the engine.

        In Ray mode this blocks on ``ray.get`` until the remote call returns.
        """
        if not self.engine_use_ray:
            self.engine.remove_logger(logger_name=logger_name)
            return
        ray.get(
            self.engine.remove_logger.remote(  # type: ignore
                logger_name=logger_name))