async_llm_engine.py 51.2 KB
Newer Older
1
2
import asyncio
import time
Antoni Baum's avatar
Antoni Baum committed
3
from functools import partial
4
5
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                    Mapping, Optional, Set, Tuple, Type, Union)
6

7
from typing_extensions import assert_never
8

9
import vllm.envs as envs
10
11
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
12
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
13
from vllm.engine.arg_utils import AsyncEngineArgs
14
from vllm.engine.async_timeout import asyncio_timeout
15
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
16
                                    PromptComponents, SchedulerOutputState)
17
from vllm.engine.metrics_types import StatLoggerBase
18
from vllm.executor.executor_base import ExecutorAsyncBase
19
from vllm.executor.ray_utils import initialize_ray_cluster, ray
20
21
22
from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
                         SingletonPromptInputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.logger import init_logger
24
from vllm.lora.request import LoRARequest
25
26
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
27
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
28
from vllm.sampling_params import SamplingParams
29
from vllm.sequence import ExecuteModelRequest, SamplerOutput
30
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
31
from vllm.usage.usage_lib import UsageContext
32
from vllm.utils import print_warning_once
33
34

# Module-level logger, created through vllm.logger.init_logger for this
# module's name.
logger = init_logger(__name__)
35
# Engine iteration timeout read from vllm.envs at import time. The `_S`
# suffix suggests the unit is seconds -- NOTE(review): confirm in vllm.envs.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
36

Antoni Baum's avatar
Antoni Baum committed
37

38
39
40
41
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine's background task has finished unexpectedly."""


42
43
44
45
46
47
48
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Inspect the finished `engine.run_engine_loop()` task and report on it.

    That task is expected to loop forever, so it can only complete through
    cancellation or an exception.

    Args:
        task: The completed background task.
        error_callback: Invoked with the exception when the task failed
            (or, unexpectedly, returned a value).

    Raises:
        AsyncEngineDeadError: If the task ended for any reason other than
            cancellation. The original exception is chained as the cause.
    """
    try:
        finished_value = task.result()
        # A normal return is itself an error condition. Raising here routes
        # it through the generic failure handler below.
        message = ("The engine background task should never finish without an "
                   f"exception. {finished_value}")
        raise AssertionError(message)
    except asyncio.exceptions.CancelledError:
        # Cancellation is treated as a graceful shutdown, which should only
        # happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as err:
        logger.error("Engine background task failed", exc_info=err)
        error_callback(err)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from err
68
69


70
71
72
# Sentinel placed on an AsyncStream's queue by finish() when the stream ends
# without a raisable exception; AsyncStream.generator() returns once it
# dequeues this object.
STOP_ITERATION = Exception()  # Sentinel


Antoni Baum's avatar
Antoni Baum committed
73
class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator.

    Items are buffered in an unbounded ``asyncio.Queue``. Once the stream has
    been finished, further ``put``/``finish`` calls are ignored.
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        """Create an empty, unfinished stream.

        Args:
            request_id: Identifier of the request this stream belongs to.
            cancel: Callback invoked with ``request_id`` when the consumer
                closes the generator before the stream ends.
        """
        self.request_id = request_id
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Enqueue ``item`` for the consumer; a no-op once finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Mark the stream as finished and enqueue its terminator.

        Args:
            exception: Exception instance or class to raise in the consumer.
                Anything that is not raisable (including ``None``) ends the
                stream normally via the ``STOP_ITERATION`` sentinel.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        """Whether ``finish`` has been called on this stream."""
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Yield queued outputs until the stream terminates.

        Returns when the ``STOP_ITERATION`` sentinel is dequeued and raises
        any other raisable item taken from the queue.

        Raises:
            asyncio.CancelledError: If the generator is closed early; the
                cancel callback is invoked with the request id first.
        """
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    # Identity check: the sentinel is one specific object, so
                    # `is` cannot be fooled by an exception type that
                    # overrides __eq__.
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """Return True if ``value`` is an exception instance or class."""
        return isinstance(value, BaseException) or \
                (isinstance(value, type) and \
                 issubclass(value, BaseException))

Antoni Baum's avatar
Antoni Baum committed
122

123
124
125
126
127
class RequestTracker:
    """Synchronous bookkeeping for the requests handled by the engine loop.

    New and aborted requests are queued here and handed over in one batch by
    ``get_new_and_aborted_requests``; outputs and exceptions are routed to
    the per-request ``AsyncStream``.
    """

    def __init__(self) -> None:
        # Streams of the requests that have been handed to the engine.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Ids of requests aborted since the last hand-over.
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, engine add_request kwargs) pairs awaiting hand-over.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set whenever a request is added; cleared by wait_for_new_requests.
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        """Return True if ``item`` is the id of a tracked stream."""
        return item in self._request_streams

    def __len__(self) -> int:
        """Return the number of tracked streams."""
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Abort one request with ``exc``, or every tracked request when
        ``request_id`` is None."""
        if request_id is None:
            # Iterate over a snapshot of the ids: abort_request() pops
            # entries from the dict while we walk it.
            for tracked_id in list(self._request_streams):
                self.abort_request(tracked_id, exception=exc)
            return

        self.abort_request(request_id, exception=exc)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Forward an engine output to the stream of its request.

        A finished output also removes the stream from tracking and closes
        it.
        """
        rid = request_output.request_id
        is_done = request_output.finished

        streams = self._request_streams
        stream = streams.pop(rid, None) if is_done else streams.get(rid)

        # The stream can be missing when the request was aborted while this
        # output was being produced.
        if stream is not None:
            stream.put(request_output)
            if is_done:
                stream.finish()

        if is_done and verbose:
            logger.info("Finished request %s.", rid)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Abort ``request_id`` so that its stream raises ``exception``."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Queue a request for the next background loop iteration.

        Returns:
            The stream on which the request's outputs will be delivered.

        Raises:
            KeyError: If ``request_id`` is already being tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        # The stream cancels itself through abort_request with the same
        # verbosity as this call.
        on_cancel = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, on_cancel)

        request_kwargs = {"request_id": request_id}
        request_kwargs.update(engine_add_request_kwargs)
        self._new_requests.put_nowait((stream, request_kwargs))

        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Record ``request_id`` as aborted and finish its stream, if any.

        The id is queued so that the abort is reported on the next
        background loop iteration.
        """
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is None:
            return
        stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Drain both queues and return ``(new request kwargs, aborted ids)``.

        A new request whose id was aborted before hand-over is cancelled
        here and appears in neither result.
        """
        pending: List[Dict] = []
        aborted_ids: Set[str] = set()

        while not self._aborted_requests.empty():
            aborted_ids.add(self._aborted_requests.get_nowait())

        while not self._new_requests.empty():
            stream, request_kwargs = self._new_requests.get_nowait()
            rid = stream.request_id
            if rid in aborted_ids:
                # Aborted before it ever reached the engine.
                stream.finish(asyncio.CancelledError)
                aborted_ids.discard(rid)
                continue
            self._request_streams[rid] = stream
            pending.append(request_kwargs)

        return pending, aborted_ids

    async def wait_for_new_requests(self):
        """Wait for the new-request event unless requests are already
        queued, then clear the event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Return True if at least one new request awaits hand-over."""
        return not self._new_requests.empty()
255

Antoni Baum's avatar
Antoni Baum committed
256
257
258
259

class _AsyncLLMEngine(LLMEngine):
    """LLMEngine subclass that adds coroutine variants of its methods."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Run one engine iteration for ``virtual_engine``.

        Unless the cached sequence groups still have steps left, the
        scheduler is asked for a new batch. A non-empty batch is executed
        through the model executor's async entry point, and the result is
        queued on the scheduler context and post-processed (immediately, or
        later when async output processing is allowed).

        Returns:
            The ``request_outputs`` list held by the scheduler context of
            this virtual engine.
        """
        # State cached by an earlier iteration (fields are None on the first
        # iteration).
        cached = self.cached_scheduler_outputs[virtual_engine]
        seq_groups_meta = cached.seq_group_metadata_list
        sched_outputs = cached.scheduler_outputs
        async_proc_ok = cached.allow_async_output_proc

        sched_ctx = self.scheduler_contexts[virtual_engine]

        # Only call the scheduler again once the current batch has no
        # remaining steps.
        if not self._has_remaining_steps(seq_groups_meta):

            # A new scheduler iteration starts with an empty output list.
            sched_ctx.request_outputs.clear()

            schedule_result = self.scheduler[virtual_engine].schedule()
            seq_groups_meta, sched_outputs, async_proc_ok = schedule_result

            # When this iteration cannot use the async postprocessor, any
            # pending async work must be drained before moving forward.
            if not async_proc_ok and len(sched_ctx.output_queue) > 0:
                self._process_model_outputs(virtual_engine=virtual_engine,
                                            is_async=True)

            # With lookahead slots in multi-step mode, keep the scheduler
            # outputs around for the next iteration.
            if (self.scheduler_config.is_multi_step
                    and sched_outputs.num_lookahead_slots > 0):
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_groups_meta, sched_outputs,
                    async_proc_ok)

        assert seq_groups_meta is not None
        assert sched_outputs is not None

        assert not (self.scheduler_config.is_multi_step and async_proc_ok)

        if sched_outputs.is_empty():
            # Nothing to execute; flush whatever the async postprocessor
            # still holds.
            if len(sched_ctx.output_queue) > 0:
                assert not self.scheduler_config.is_multi_step
                self._process_model_outputs(virtual_engine=virtual_engine,
                                            is_async=True)
            outputs = []
        else:
            finished_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Reuse a cached last output from the previous iteration, if any.
            # For PP this avoids a separate broadcast over all stages, which
            # would let one virtual engine's microbatch block the pipeline.
            last_sampled = self._get_last_sampled_token_ids(virtual_engine)

            model_request = ExecuteModelRequest(
                seq_group_metadata_list=seq_groups_meta,
                blocks_to_swap_in=sched_outputs.blocks_to_swap_in,
                blocks_to_swap_out=sched_outputs.blocks_to_swap_out,
                blocks_to_copy=sched_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=sched_outputs.num_lookahead_slots,
                running_queue_size=sched_outputs.running_queue_size,
                finished_requests_ids=finished_ids,
                # The request carries the last sampled token ids to each of
                # the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled)

            if async_proc_ok:
                model_request.async_callback = self.async_callback[
                    virtual_engine]

            # Run the model.
            outputs = await self.model_executor.execute_model_async(
                model_request)

            # Must happen here so the last step's sampled token ids can be
            # passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)

        # Mark the current step as done on every sequence group.
        if self.scheduler_config.is_multi_step:
            for group_meta in seq_groups_meta:
                group_meta.finish_step()

        if not self._has_remaining_steps(seq_groups_meta):
            # All steps are finished: reset the multi-step cache.
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            # Hand the results to the output queue of this engine.
            sched_ctx.output_queue.append(
                (outputs, seq_groups_meta, sched_outputs))

            if outputs and async_proc_ok:
                assert len(
                    outputs
                ) == 1, "Multi step decoding does not work with async output processing."  # noqa: E501
                self._advance_to_next_step(
                    outputs[0], seq_groups_meta,
                    sched_outputs.scheduled_seq_groups)

            if not async_proc_ok:
                self._process_model_outputs(virtual_engine=virtual_engine,
                                            is_async=False)

                # Stats logging.
                self.do_log_stats(sched_outputs, outputs)

                # Tracing.
                self.do_tracing(sched_outputs)

        else:
            sched_ctx.request_outputs = []

        if not self.has_unfinished_requests():
            # Drain the async postprocessor, if it holds anything.
            if len(sched_ctx.output_queue) > 0:
                assert not self.scheduler_config.is_multi_step
                self._process_model_outputs(virtual_engine=virtual_engine,
                                            is_async=True)
            assert len(sched_ctx.output_queue) == 0

        return sched_ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Await the model executor's request to stop the remote worker
        execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def _tokenize_prompt_async(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        """Async version of :meth:`_tokenize_prompt`.

        Encodes ``prompt`` through the tokenizer group's ``encode_async``.
        """
        tokenizer_group = self.get_tokenizer_group(
            missing_msg="prompts must be None if skip_tokenizer_init is True")

        token_ids = await tokenizer_group.encode_async(
            request_id=request_id,
            prompt=prompt,
            lora_request=lora_request)
        return token_ids

    async def _extract_prompt_components_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> PromptComponents:
        """Async version of :meth:`_extract_prompt_components`.

        Returns:
            A ``(prompt, prompt_token_ids, multi_modal_data)`` tuple. Text is
            tokenized unless the input already carries ``prompt_token_ids``.
        """
        if isinstance(inputs, str):
            # A bare string is the prompt itself and has no multi-modal data.
            token_ids = await self._tokenize_prompt_async(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
            )
            return inputs, token_ids, None

        if isinstance(inputs, dict):
            if "prompt_token_ids" in inputs:
                # Already tokenized: there is no text prompt to report.
                prompt = None
                token_ids = inputs["prompt_token_ids"]
            else:
                # NOTE: This extra assignment is required to pass mypy
                prompt = parsed_prompt = inputs["prompt"]
                token_ids = await self._tokenize_prompt_async(
                    parsed_prompt,
                    request_id=request_id,
                    lora_request=lora_request,
                )

            return prompt, token_ids, inputs.get("multi_modal_data")

        assert_never(inputs)

    async def _process_encoder_decoder_prompt_async(
        self,
        inputs: PromptInputs,
        request_id: str,
    ) -> EncoderDecoderLLMInputs:
        """Async version of :meth:`_process_encoder_decoder_prompt`.

        A non-explicit prompt is treated as the encoder prompt with empty
        decoder components. For an explicit prompt, the encoder and decoder
        prompts are extracted concurrently when a decoder prompt is given.
        """
        encoder_comps: PromptComponents
        decoder_comps: DecoderPromptComponents

        if not is_explicit_encoder_decoder_prompt(inputs):
            encoder_comps = await self._extract_prompt_components_async(
                inputs,
                request_id=request_id,
            )
            decoder_comps = None, None, None
        else:
            encoder_coro = self._extract_prompt_components_async(
                inputs["encoder_prompt"],
                request_id=request_id,
            )

            decoder_input = inputs["decoder_prompt"]
            if decoder_input is None:
                encoder_comps = await encoder_coro
                decoder_comps = None, None, None
            else:
                decoder_coro = self._extract_prompt_components_async(
                    decoder_input,
                    request_id=request_id,
                )
                encoder_comps, decoder_comps = await asyncio.gather(
                    encoder_coro, decoder_coro)

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)

    async def _process_decoder_only_prompt_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        """Async version of :meth:`_process_decoder_only_prompt`.

        Extracts the prompt components and builds the decoder-only inputs
        from them.
        """
        components = await self._extract_prompt_components_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        )

        return self._build_decoder_only_llm_inputs(
            components,
            prompt_adapter_request=prompt_adapter_request,
        )

    async def process_model_inputs_async(
        self,
        inputs: PromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
        """Async version of :meth:`process_model_inputs`.

        Raises:
            ValueError: If an explicit encoder-decoder prompt is given to a
                decoder-only model.
        """
        if self.is_encoder_decoder_model():
            # Encoder-decoder models need the prompt mapped onto separate
            # encoder and decoder inputs.
            model_inputs = await self._process_encoder_decoder_prompt_async(
                inputs,
                request_id=request_id,
            )
        elif is_explicit_encoder_decoder_prompt(inputs):
            raise ValueError("Cannot pass encoder-decoder prompt "
                             "to decoder-only models")
        else:
            # Decoder-only path.
            model_inputs = await self._process_decoder_only_prompt_async(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )

        return self.input_processor(model_inputs)

    async def add_request_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        """Async version of :meth:`add_request`.

        Processes the inputs asynchronously and then registers the processed
        request. ``arrival_time`` defaults to the current time, taken before
        the inputs are processed.

        Raises:
            ValueError: If ``lora_request`` is given but LoRA is not enabled.
        """
        lora_enabled = bool(self.lora_config)
        if lora_request is not None and not lora_enabled:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if arrival_time is None:
            arrival_time = time.time()

        processed = await self.process_model_inputs_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
        )

    async def check_health_async(self) -> None:
        """Run the health checks of the tokenizer (when one is set) and of
        the model executor."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
580

581

582
class AsyncLLMEngine:
583
    """An asynchronous wrapper for :class:`LLMEngine`.
584

585
586
587
588
589
    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.
590
591
592
593
594

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
Zhuohan Li's avatar
Zhuohan Li committed
595
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
            async frontend will be executed in a separate process from the
            model workers.
598
        log_requests: Whether to log the requests.
599
600
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
601
602
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
603
    """
604

Antoni Baum's avatar
Antoni Baum committed
605
606
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

607
608
609
610
611
    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Build the wrapped engine and reset the background-loop state.

        Args:
            worker_use_ray: Stored on the instance; consulted by
                :meth:`_init_engine` when choosing Ray resources.
            engine_use_ray: Whether the engine is created as a Ray actor.
                Deprecated: only accepted when
                ``envs.VLLM_ALLOW_ENGINE_USE_RAY`` is set.
            *args: Forwarded to :meth:`_init_engine`.
            log_requests: Stored on the instance as ``log_requests``.
            start_engine_loop: Stored on the instance as
                ``start_engine_loop``.
            **kwargs: Forwarded to :meth:`_init_engine`.

        Raises:
            ValueError: If ``engine_use_ray`` is set without
                ``VLLM_ALLOW_ENGINE_USE_RAY``.
        """
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests

        # Validate the deprecated flag *before* constructing the engine, so a
        # rejected configuration fails fast instead of first building (and
        # then abandoning) a fully initialized engine or Ray actor.
        if self.engine_use_ray:
            print_warning_once(
                "DEPRECATED. `--engine-use-ray` is deprecated and will "
                "be removed in a future update. "
                "See https://github.com/vllm-project/vllm/issues/7045.")

            if envs.VLLM_ALLOW_ENGINE_USE_RAY:
                print_warning_once(
                    "VLLM_ALLOW_ENGINE_USE_RAY is set, force engine use Ray")
            else:
                raise ValueError("`--engine-use-ray` is deprecated. "
                                 "Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to "
                                 "force use it")

        self.engine = self._init_engine(*args, **kwargs)

        self.background_loop: Optional[asyncio.Future] = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Lazy initialized fields
        self._request_tracker: RequestTracker

644
    @classmethod
    def _get_executor_cls(
            cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
        """Select the async executor class for ``engine_config``.

        The choice depends on
        ``parallel_config.distributed_executor_backend`` and
        ``device_config.device_type``:

        - a class given as the backend is validated and returned as-is;
        - the "neuron", "tpu", "cpu", "openvino" and "xpu" device types pick
          their dedicated executors;
        - any other device type falls back to the Ray, multiprocessing or
          plain GPU executor, depending on the backend.

        Executor modules are imported lazily inside the matching branch.

        Raises:
            TypeError: If a backend class is not a subclass of
                ``ExecutorAsyncBase``.
            AssertionError: If a backend is set where the TPU (non-Ray) or
                OpenVINO branch requires it to be ``None``.
            RuntimeError: If the backend is not recognised for XPU.
        """
        parallel_config = engine_config.parallel_config
        backend = parallel_config.distributed_executor_backend

        # A user-supplied executor class wins over any device-based choice.
        if isinstance(backend, type):
            if not issubclass(backend, ExecutorAsyncBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorAsyncBase. Got {backend}.")
            if backend.uses_ray:  # type: ignore
                initialize_ray_cluster(engine_config.parallel_config)
            return backend

        device_type = engine_config.device_config.device_type

        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            return NeuronExecutorAsync

        if device_type == "tpu":
            if backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                return RayTPUExecutorAsync
            assert backend is None
            from vllm.executor.tpu_executor import TPUExecutorAsync
            return TPUExecutorAsync

        if device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutorAsync
            return CPUExecutorAsync

        if device_type == "openvino":
            assert backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            return OpenVINOExecutorAsync

        if device_type == "xpu":
            if backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                return XPUExecutorAsync
            if backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                return RayXPUExecutorAsync
            if backend == "mp":
                # NOTE(review): the GPU "mp" branch below does not start a
                # Ray cluster; this call looks copied from the "ray" branch.
                # Kept as-is to preserve behaviour -- confirm it is needed.
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.multiproc_xpu_executor import (
                    MultiprocessingXPUExecutorAsync)
                return MultiprocessingXPUExecutorAsync
            raise RuntimeError(
                "Not supported distributed execution model on XPU device.")

        # Every remaining device type uses the GPU executors.
        if backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            return RayGPUExecutorAsync

        if backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            return MultiprocessingGPUExecutorAsync

        from vllm.executor.gpu_executor import GPUExecutorAsync
        return GPUExecutorAsync

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        Builds the engine config from ``engine_args``, resolves the executor
        class through :meth:`_get_executor_cls` and instantiates ``cls`` with
        the config entries plus the logging, loop and usage options.
        """
        config = engine_args.create_engine_config()

        # Only verify that Ray is available when the engine itself is
        # requested to run through Ray.
        if engine_args.engine_use_ray:
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()

        executor_cls = cls._get_executor_cls(config)

        return cls(
            executor_cls.uses_ray,
            engine_args.engine_use_ray,
            **config.to_dict(),
            executor_class=executor_cls,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

739
740
    @property
    def is_running(self) -> bool:
        """Whether a background loop exists and its task is not done yet."""
        if self.background_loop is None:
            return False
        task = self._background_loop_unshielded
        if task is None:
            return False
        return not task.done()

    @property
    def is_stopped(self) -> bool:
        """Whether the engine errored or its background task has finished."""
        if self.errored:
            return True
        if self.background_loop is None:
            return False
        task = self._background_loop_unshielded
        if task is None:
            return False
        return task.done()

    @property
    def errored(self) -> bool:
        """Whether an exception has been recorded in ``_errored_with``."""
        recorded = self._errored_with
        return recorded is not None

755
756
757
758
759
    @property
    def limit_concurrency(self) -> Optional[int]:
        """Maximum number of concurrently running requests.

        This implementation always returns ``None``.
        """
        return None

760
761
762
763
764
765
    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` in ``_errored_with``."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Record ``exc`` and hand it to the request tracker."""
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)
766

767
768
769
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer for ``lora_request``.

        With a local engine the tokenizer group is asked for the LoRA
        tokenizer; with a Ray engine the call goes through the actor's
        ``get_tokenizer`` method.
        """
        if not self.engine_use_ray:
            tokenizer_group = self.engine.get_tokenizer_group()
            return await tokenizer_group.get_lora_tokenizer_async(lora_request)

        return await self.engine.get_tokenizer.remote(  # type: ignore
            lora_request)
777

778
    def start_background_loop(self) -> None:
        """Start the background loop.

        Creates a fresh ``RequestTracker``, schedules
        :meth:`run_engine_loop` as a task on the current event loop and keeps
        both the raw task and a shielded wrapper around it.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the background loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        event_loop = asyncio.get_event_loop()
        engine_task = event_loop.create_task(self.run_engine_loop())
        self._background_loop_unshielded = engine_task
        on_done = partial(_log_task_completion,
                          error_callback=self._error_callback)
        engine_task.add_done_callback(on_done)
        self.background_loop = asyncio.shield(engine_task)
Antoni Baum's avatar
Antoni Baum committed
793

794
795
796
797
798
799
800
801
802
803
804
805
806
807
    def shutdown_background_loop(self) -> None:
        """
        Shut down the background loop.

        Cancels the unshielded task, if there is one, and clears both loop
        references. This method needs to be called during cleanup to remove
        references to `self` and properly GC the resources held by the
        async LLM engine (e.g., the executors as well as their resources).
        """
        task = self._background_loop_unshielded
        if task is not None:
            task.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

Antoni Baum's avatar
Antoni Baum committed
808
809
    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
        """Instantiate the engine, locally or as a Ray actor.

        Without ``engine_use_ray`` the engine class is called directly.
        Otherwise it is wrapped with ``ray.remote``: with ``num_cpus=0`` when
        the workers use Ray, else with a GPU request derived from
        ``kwargs["cache_config"]`` and ``kwargs["parallel_config"]``.
        """
        if not self.engine_use_ray:
            return self._engine_class(*args, **kwargs)

        if self.worker_use_ray:
            actor_cls = ray.remote(num_cpus=0)(self._engine_class)
            return actor_cls.remote(*args, **kwargs)

        # FIXME(woosuk): This is a bit hacky. Be careful when changing the
        # order of the arguments.
        cache_config = kwargs["cache_config"]
        parallel_config = kwargs["parallel_config"]
        single_device = (parallel_config.tensor_parallel_size == 1
                         and parallel_config.pipeline_parallel_size == 1)
        num_gpus = (cache_config.gpu_memory_utilization
                    if single_device else 1)
        actor_cls = ray.remote(num_gpus=num_gpus)(self._engine_class)
        return actor_cls.remote(*args, **kwargs)

828
    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Feeds new requests from the tracker into the engine, aborts the
        requests the tracker reports as aborted, runs one engine step and
        routes every output to the tracker.

        Returns True if there are in-progress requests."""
        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())

        # Add the requests into the vLLM engine's waiting queue.
        # TODO: Maybe add add_request_batch to reduce Ray overhead
        for request_kwargs in new_requests:
            try:
                if self.engine_use_ray:
                    await self.engine.add_request.remote(  # type: ignore
                        **request_kwargs)
                else:
                    await self.engine.add_request_async(**request_kwargs)
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    request_kwargs["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if aborted_requests:
            await self._engine_abort(aborted_requests)

        if self.engine_use_ray:
            request_outputs = await self.engine.step.remote()  # type: ignore
        else:
            request_outputs = await self.engine.step_async(virtual_engine)

        # Put the outputs into the corresponding streams.
        in_progress = False
        for request_output in request_outputs:
            self._request_tracker.process_request_output(
                request_output, verbose=self.log_requests)
            if not request_output.finished:
                in_progress = True

        return in_progress
869

Antoni Baum's avatar
Antoni Baum committed
870
871
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Abort ``request_ids`` in the engine, locally or through Ray."""
        if not self.engine_use_ray:
            self.engine.abort_request(request_ids)
            return
        await self.engine.abort_request.remote(request_ids)  # type: ignore

    async def run_engine_loop(self):
        """Drive the engine forever, one step task per virtual engine.

        While no virtual engine has work, the remote worker execution loop is
        stopped and the coroutine waits for new requests; it then starts one
        :meth:`engine_step` task per virtual engine. Each iteration waits for
        the first step task to complete and, for every completed task, either
        schedules another step or marks that virtual engine as idle.

        Raises:
            asyncio.TimeoutError: If waiting for a step exceeds
                ``ENGINE_ITERATION_TIMEOUT_S``; the error is recorded through
                :meth:`set_errored` before being re-raised.
        """
        # With a Ray engine actor only a single step task is driven.
        pp_size = (1 if self.engine_use_ray else
                   self.engine.parallel_config.pipeline_parallel_size)
        busy = [False] * pp_size

        while True:
            if not any(busy):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                if self.engine_use_ray:
                    await (self.engine.stop_remote_worker_execution_loop.
                           remote()  # type: ignore
                           )
                else:
                    await self.engine.stop_remote_worker_execution_loop_async()
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")
                step_tasks = [
                    asyncio.create_task(self.engine_step(ve))
                    for ve in range(pp_size)
                ]
                busy = [True] * pp_size

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    finished_tasks, _ = await asyncio.wait(
                        step_tasks, return_when=asyncio.FIRST_COMPLETED)
                    for _ in range(pp_size):
                        await asyncio.sleep(0)
                for task in finished_tasks:
                    step_in_progress = task.result()
                    virtual_engine = step_tasks.index(task)
                    if self.engine_use_ray:
                        engine_has_unfinished = (
                            await (self.engine.
                                   has_unfinished_requests_for_virtual_engine.
                                   remote(  # type: ignore
                                       virtual_engine)))
                    else:
                        engine_has_unfinished = (
                            self.engine.
                            has_unfinished_requests_for_virtual_engine(
                                virtual_engine))
                    if step_in_progress or engine_has_unfinished:
                        # More work for this virtual engine: step it again.
                        step_tasks[virtual_engine] = (asyncio.create_task(
                            self.engine_step(virtual_engine)))
                        busy[virtual_engine] = True
                    else:
                        busy[virtual_engine] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise
            await asyncio.sleep(0)

943
944
    # This method does not need to be async, but kept that way
    # for backwards compatibility.
    async def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Register a request with the tracker and return its output stream.

        If the background loop is not running it is started first, provided
        ``start_engine_loop`` is set.

        Args:
            request_id: The unique id of the request.
            inputs: The inputs to the LLM.
            params: Sampling or pooling parameters of the request.
            arrival_time: Arrival timestamp; ``time.time()`` is used when it
                is ``None``.
            lora_request: LoRA request to use, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            The generator obtained from the stream that the request tracker
            created for this request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                ``start_engine_loop`` is false.
        """
        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        # Test for None explicitly: `arrival_time or time.time()` would
        # silently replace a caller-supplied timestamp of 0.
        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            inputs=inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request)

        return stream.generator()
976

977
    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        This method is an async generator. It adds the request to the
        waiting queue of the LLMEngine and streams the outputs from the
        LLMEngine back to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use
                for generation, if any.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results (`request` is the incoming HTTP request)
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        stream = await self.add_request(
            request_id,
            inputs,
            sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
        )
        async for output in stream:
            # Check the output type before handing it to the caller.
            yield LLMEngine.validate_output(output, RequestOutput)
1059
1060
1061

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model.

        This method is an async generator. It adds the request to the
        waiting queue of the LLMEngine and streams the outputs from the
        LLMEngine back to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.

        Yields:
            The output `EmbeddingRequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results (`request` is the incoming HTTP request)
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        stream = await self.add_request(
            request_id,
            inputs,
            pooling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )
        async for output in stream:
            # Check the output type before handing it to the caller.
            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
1136

Antoni Baum's avatar
Antoni Baum committed
1137
1138
    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if self.is_running:
            return self._abort(request_id)

        raise AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")
1154

Antoni Baum's avatar
Antoni Baum committed
1155
    def _abort(self, request_id: str) -> None:
        """Abort a request.

        Asks the request tracker to abort ``request_id`` with
        ``asyncio.CancelledError`` as the exception. If the request is
        finished or not found, this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
        tracker = self._request_tracker
        tracker.abort_request(
            request_id,
            exception=asyncio.CancelledError,
            verbose=self.log_requests,
        )
1167

1168
1169
1170
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        if not self.engine_use_ray:
            return self.engine.get_model_config()
        return await self.engine.get_model_config.remote()  # type: ignore

1175
1176
1177
1178
1179
1180
1181
1182
    async def get_parallel_config(self) -> ParallelConfig:
        """Get the parallel configuration of the vLLM engine."""
        if not self.engine_use_ray:
            return self.engine.get_parallel_config()
        return await self.engine.get_parallel_config.remote(  # type: ignore
        )

1183
1184
1185
1186
1187
1188
1189
1190
    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        if not self.engine_use_ray:
            return self.engine.get_decoding_config()
        return await self.engine.get_decoding_config.remote(  # type: ignore
        )

1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
    async def get_scheduler_config(self) -> SchedulerConfig:
        """Get the scheduling configuration of the vLLM engine."""
        if not self.engine_use_ray:
            return self.engine.get_scheduler_config()
        return await self.engine.get_scheduler_config.remote(  # type: ignore
        )

    async def get_lora_config(self) -> LoRAConfig:
        """Get the lora configuration of the vLLM engine."""
        if not self.engine_use_ray:
            return self.engine.get_lora_config()
        return await self.engine.get_lora_config.remote(  # type: ignore
        )

1207
1208
1209
1210
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the engine to log its stats.

        Args:
            scheduler_outputs: Optional scheduler outputs passed on to the
                engine's ``do_log_stats``.
            model_output: Optional sampler outputs passed on to the engine's
                ``do_log_stats``.
        """
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote(  # type: ignore
                scheduler_outputs, model_output)
        else:
            # Forward both arguments, as the Ray path does; previously they
            # were silently dropped for a local engine.
            self.engine.do_log_stats(scheduler_outputs, model_output)
1216

1217
    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy.

        Raises:
            AsyncEngineDeadError: If the background loop is stopped.
            RuntimeError: If the Ray engine actor reports a
                ``RayActorError``.
        """
        started = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        if not self.engine_use_ray:
            await self.engine.check_health_async()
        else:
            try:
                await self.engine.check_health.remote()  # type: ignore
            except ray.exceptions.RayActorError as e:
                raise RuntimeError("Engine is dead.") from e
        logger.debug("Health check took %fs", time.perf_counter() - started)
1232
1233
1234
1235
1236
1237
1238

    async def is_tracing_enabled(self) -> bool:
        """Return the engine's ``is_tracing_enabled`` result."""
        if not self.engine_use_ray:
            return self.engine.is_tracing_enabled()
        return await self.engine.is_tracing_enabled.remote(  # type: ignore
        )
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register ``logger`` under ``logger_name`` on the engine.

        With a Ray engine the call blocks on ``ray.get`` until the actor has
        processed it.
        """
        if not self.engine_use_ray:
            self.engine.add_logger(logger_name=logger_name, logger=logger)
            return
        pending = self.engine.add_logger.remote(  # type: ignore
            logger_name=logger_name, logger=logger)
        ray.get(pending)

    def remove_logger(self, logger_name: str) -> None:
        """Remove the logger registered under ``logger_name`` from the engine.

        With a Ray engine the call blocks on ``ray.get`` until the actor has
        processed it.
        """
        if not self.engine_use_ray:
            self.engine.remove_logger(logger_name=logger_name)
            return
        pending = self.engine.remove_logger.remote(  # type: ignore
            logger_name=logger_name)
        ray.get(pending)
1255
1256
1257
1258
1259
1260

    async def start_profile(self) -> None:
        """Run ``start_profile`` on the workers through the model executor."""
        executor = self.engine.model_executor
        executor._run_workers("start_profile")

    async def stop_profile(self) -> None:
        """Run ``stop_profile`` on the workers through the model executor."""
        executor = self.engine.model_executor
        executor._run_workers("stop_profile")