async_llm_engine.py 52.3 KB
Newer Older
1
2
import asyncio
import time
Antoni Baum's avatar
Antoni Baum committed
3
from functools import partial
4
5
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                    Mapping, Optional, Set, Tuple, Type, Union)
6

7
from typing_extensions import assert_never
8

9
import vllm.envs as envs
10
11
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
12
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
13
from vllm.engine.arg_utils import AsyncEngineArgs
14
from vllm.engine.async_timeout import asyncio_timeout
15
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
16
                                    PromptComponents, SchedulerOutputState)
17
from vllm.engine.metrics_types import StatLoggerBase
18
from vllm.executor.executor_base import ExecutorAsyncBase
19
from vllm.executor.ray_utils import initialize_ray_cluster, ray
20
21
22
from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
                         SingletonPromptInputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.logger import init_logger
24
from vllm.lora.request import LoRARequest
25
from vllm.model_executor.layers.sampler import SamplerOutput
26
27
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
28
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
29
from vllm.sampling_params import SamplingParams
30
from vllm.sequence import ExecuteModelRequest
31
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
32
from vllm.usage.usage_lib import UsageContext
33
from vllm.utils import print_warning_once
34
35

# Logger shared by everything in this module.
logger = init_logger(__name__)
# Engine iteration timeout, read once from ``vllm.envs`` when this module is
# imported (the ``_S`` suffix suggests seconds).
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
37

Antoni Baum's avatar
Antoni Baum committed
38

39
40
41
42
class AsyncEngineDeadError(RuntimeError):
    """Signals that the async engine's background task stopped unexpectedly."""


43
44
45
46
47
48
49
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback meant only for the ``engine.run_engine_loop()`` task.

    That task runs a ``while True`` loop, so the only expected way for it to
    stop is by raising. Cancellation is logged as a graceful shutdown; any
    other outcome (an exception, or even a normal return) is logged, passed
    to ``error_callback``, and re-raised as :class:`AsyncEngineDeadError`.

    Args:
        task: The completed engine-loop task.
        error_callback: Called with the exception that ended the task.

    Raises:
        AsyncEngineDeadError: When the task ended other than by cancellation.
    """
    try:
        unexpected_value = task.result()
        # A normal return is itself a failure. This AssertionError is raised
        # inside the ``try`` on purpose so the generic handler below reports
        # it like any other crash.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {unexpected_value}")
    except asyncio.exceptions.CancelledError:
        # We assume cancellation only happens during program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as failure:
        logger.error("Engine background task failed", exc_info=failure)
        error_callback(failure)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from failure
69
70


71
72
73
# Sentinel that AsyncStream.finish() enqueues when it is called without a
# raisable exception. AsyncStream.generator() returns (ends the stream) when it
# dequeues this object instead of raising it. It is an Exception instance so
# that it passes AsyncStream._is_raisable().
STOP_ITERATION = Exception()  # Sentinel


Antoni Baum's avatar
Antoni Baum committed
74
class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator.

    Items are buffered in an ``asyncio.Queue``. The stream ends either when
    :meth:`finish` is called without an exception (the module-level
    ``STOP_ITERATION`` sentinel is enqueued) or with an exception, which the
    generator then raises to the consumer.
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        # Called with ``request_id`` if the consumer closes the generator.
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Enqueue an output (or exception); silently ignored once the
        stream has been finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Mark the stream finished (idempotent).

        Args:
            exception: An exception instance or class to raise to the
                consumer. When it is None (or not raisable), the stream ends
                normally instead.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has been called."""
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Yield queued outputs until the stream is finished.

        Raises:
            BaseException: Whatever exception was passed to :meth:`finish`.
            asyncio.CancelledError: If the consumer closes the generator
                early; the request is cancelled via the ``cancel`` callback.
        """
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    # Identity check: the sentinel must never be confused
                    # with an exception whose ``__eq__`` is permissive.
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """Return True if ``value`` is an exception instance or class."""
        return isinstance(value, BaseException) or \
                (isinstance(value, type) and \
                 issubclass(value, BaseException))

Antoni Baum's avatar
Antoni Baum committed
123

124
125
126
127
128
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    ``add_request`` and ``abort_request`` only buffer work in queues; the
    engine loop collects it through :meth:`get_new_and_aborted_requests`.
    A request's :class:`AsyncStream` lives in ``_request_streams`` from the
    moment it is handed to the engine until it finishes or is aborted.
    """

    def __init__(self) -> None:
        self._request_streams: Dict[str, AsyncStream] = {}
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # IDs queued in _new_requests that get_new_and_aborted_requests has
        # not drained yet (and that were not aborted since). This lets
        # add_request reject a duplicate ID submitted before the engine loop
        # picked up the first one. Otherwise the second stream would
        # overwrite the first in _request_streams and the first would never
        # finish.
        self._pending_request_ids: Set[str] = set()
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None)."""
        if request_id is not None:
            self.abort_request(request_id, exception=exc)
        else:
            # NB: tuple() used here because self.abort_request pops the stream
            # out of self._request_streams, so we can't iterate on it directly
            for rid in tuple(self._request_streams.keys()):
                self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        Forwards the output to the request's stream. If the output is final,
        the stream is removed from tracking and finished.
        """
        request_id = request_output.request_id
        finished = request_output.finished

        if finished:
            stream = self._request_streams.pop(request_id, None)
        else:
            stream = self._request_streams.get(request_id)
        # Guard against a KeyError which can occur if the request was aborted
        # while the output was generated
        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if verbose and finished:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Args:
            request_id: Unique ID of the request.
            verbose: Whether to log the addition (and a later abort).
            **engine_add_request_kwargs: Forwarded to the engine together
                with ``request_id``.

        Returns:
            The stream on which the request's outputs will be delivered.

        Raises:
            KeyError: If a request with the same ID is already being tracked
                or is still waiting to be picked up by the engine loop.
        """
        if (request_id in self._request_streams
                or request_id in self._pending_request_ids):
            raise KeyError(f"Request {request_id} already exists.")

        abort_request = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, abort_request)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))
        self._pending_request_ids.add(request_id)

        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Abort a request during next background loop iteration.

        If the request already has a tracked stream, that stream is finished
        immediately, with ``exception`` if one is given.
        """
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)
        # A still-queued request that is aborted releases its ID so it can be
        # re-submitted. The stale queue entry is cancelled when drained.
        self._pending_request_ids.discard(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        Drains both queues. Queued requests that were aborted before being
        drained are cancelled here and are not returned in either
        collection.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._aborted_requests.empty():
            request_id = self._aborted_requests.get_nowait()
            finished_requests.add(request_id)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            request_id = stream.request_id
            if request_id in finished_requests:
                # The request has already been aborted.
                stream.finish(asyncio.CancelledError)
                finished_requests.discard(request_id)
            else:
                self._request_streams[request_id] = stream
                new_requests.append(new_request)

        # The new-request queue is now empty, so nothing is pending any more.
        self._pending_request_ids.clear()

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Block until at least one new request has been queued."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Whether any new requests are waiting to be drained."""
        return not self._new_requests.empty()
256

Antoni Baum's avatar
Antoni Baum committed
257
258
259
260

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine that adds async counterparts of its methods."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Run one decoding iteration and return the newly produced outputs.

        The workers are driven asynchronously where possible. Unless a
        multi-step batch still has steps left, the scheduler is run first to
        pick the sequence groups and the block swap/copy operations. The
        model is then executed on that batch, its output is fed back
        (synchronously or via async output processing), and the request
        outputs collected in this virtual engine's scheduler context are
        returned.

        Args:
            virtual_engine: Index selecting which scheduler, scheduler
                context and cached scheduler outputs to use.

        Returns:
            The request outputs gathered in the scheduler context. For async
            + multi-step, an empty list is returned while steps remain.
        """
        is_multi_step = self.scheduler_config.is_multi_step
        sched_ctx = self.scheduler_contexts[virtual_engine]

        # State cached by an earlier iteration of a multi-step batch; the
        # fields are None on the first iteration.
        cached = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached.seq_group_metadata_list
        scheduler_outputs = cached.scheduler_outputs
        allow_async_output_proc = cached.allow_async_output_proc

        # Only call the scheduler again once the current batch has completed
        # all of its steps.
        if not self._has_remaining_steps(seq_group_metadata_list):
            # A new scheduling round starts with an empty output buffer.
            sched_ctx.request_outputs.clear()

            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            # Switching from async to sync mode: flush queued async outputs.
            if not allow_async_output_proc and len(sched_ctx.output_queue) > 0:
                self._process_model_outputs(virtual_engine=virtual_engine,
                                            is_async=True)

            # Async + multi-step starts each batch with a placeholder entry.
            if is_multi_step and allow_async_output_proc:
                assert len(sched_ctx.output_queue) == 0
                assert seq_group_metadata_list is not None
                sched_ctx.output_queue.append(
                    (None, seq_group_metadata_list, scheduler_outputs))

            if is_multi_step and scheduler_outputs.num_lookahead_slots > 0:
                # Keep the schedule around for the remaining lookahead steps.
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)

        async_multi_step = is_multi_step and allow_async_output_proc

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if scheduler_outputs.is_empty():
            # Nothing to run; in sync single-step mode still drain any
            # pending async outputs.
            if not async_multi_step and len(sched_ctx.output_queue) > 0:
                assert not is_multi_step
                self._process_model_outputs(virtual_engine=virtual_engine,
                                            is_async=True)
            model_output = []
        else:
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # With pipeline parallelism, forwarding the previous step's
            # sampled token ids in the request avoids a separate broadcast
            # across PP stages that would stall other virtual engines.
            last_sampled_token_ids = self._get_last_sampled_token_ids(
                virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # Lets the non-last PP stages prepare inputs in place.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                if async_multi_step:
                    callback = self.async_callback_multi_step[virtual_engine]
                else:
                    callback = self.async_callback[virtual_engine]
                execute_model_req.async_callback = callback
                execute_model_req.use_async_and_multi_step = async_multi_step

            model_output = await self.model_executor.execute_model_async(
                execute_model_req)

            # Must happen here so this step's sampled_token_ids reach the
            # next iteration under pipeline parallelism.
            if is_multi_step:
                self._update_cached_scheduler_output(virtual_engine,
                                                     model_output)

        # Advance every sequence group by one step.
        if is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if self._has_remaining_steps(seq_group_metadata_list):
            # Still inside a multi-step batch.
            if async_multi_step:
                return []
            sched_ctx.request_outputs = []
        else:
            # All steps done: drop the cached multi-step schedule.
            if is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            if async_multi_step:
                sched_ctx.output_queue.clear()
            else:
                sched_ctx.output_queue.append(
                    (model_output, seq_group_metadata_list, scheduler_outputs))

                if model_output and allow_async_output_proc:
                    assert len(
                        model_output
                    ) == 1, "Multi step decoding does not work with async output processing."  # noqa: E501
                    self._advance_to_next_step(
                        model_output[0], seq_group_metadata_list,
                        scheduler_outputs.scheduled_seq_groups)

            if not allow_async_output_proc:
                self._process_model_outputs(virtual_engine=virtual_engine,
                                            is_async=False)
                self.do_log_stats(scheduler_outputs, model_output)
                self.do_tracing(scheduler_outputs)

        if not self.has_unfinished_requests():
            # Drain the async post-processor before going idle.
            if len(sched_ctx.output_queue) > 0:
                assert not is_multi_step
                self._process_model_outputs(virtual_engine=virtual_engine,
                                            is_async=True)
            assert len(sched_ctx.output_queue) == 0

        return sched_ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def _tokenize_prompt_async(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        """Async version of :meth:`_tokenize_prompt`: encode ``prompt`` with
        the engine's tokenizer group."""
        tokenizer_group = self.get_tokenizer_group(
            missing_msg="prompts must be None if skip_tokenizer_init is True")
        token_ids = await tokenizer_group.encode_async(
            request_id=request_id,
            prompt=prompt,
            lora_request=lora_request,
        )
        return token_ids

    async def _extract_prompt_components_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> PromptComponents:
        """Async version of :meth:`_extract_prompt_components`.

        Returns a ``(prompt, prompt_token_ids, multi_modal_data)`` tuple. A
        plain string is tokenized and has no multi-modal data. A dict either
        carries ``prompt_token_ids`` directly (``prompt`` is then None) or
        carries ``prompt`` text to tokenize; its ``multi_modal_data`` entry,
        if any, is passed through.
        """
        if isinstance(inputs, str):
            str_token_ids = await self._tokenize_prompt_async(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
            )
            return inputs, str_token_ids, None

        if not isinstance(inputs, dict):
            assert_never(inputs)

        prompt: Optional[str]
        if "prompt_token_ids" in inputs:
            prompt = None
            prompt_token_ids = inputs["prompt_token_ids"]
        else:
            # NOTE: This extra assignment is required to pass mypy
            prompt = parsed_prompt = inputs["prompt"]
            prompt_token_ids = await self._tokenize_prompt_async(
                parsed_prompt,
                request_id=request_id,
                lora_request=lora_request,
            )

        return prompt, prompt_token_ids, inputs.get("multi_modal_data")

    async def _process_encoder_decoder_prompt_async(
        self,
        inputs: PromptInputs,
        request_id: str,
    ) -> EncoderDecoderLLMInputs:
        """Async version of :meth:`_process_encoder_decoder_prompt`.

        An explicit encoder/decoder prompt has both halves extracted (the
        two extractions run concurrently when a decoder prompt is present).
        Any other prompt is used for the encoder only.
        """
        encoder_comps: PromptComponents
        decoder_comps: DecoderPromptComponents

        if is_explicit_encoder_decoder_prompt(inputs):
            encoder_coro = self._extract_prompt_components_async(
                inputs["encoder_prompt"],
                request_id=request_id,
            )
            decoder_input = inputs["decoder_prompt"]

            if decoder_input is None:
                encoder_comps = await encoder_coro
                decoder_comps = None, None, None
            else:
                decoder_coro = self._extract_prompt_components_async(
                    decoder_input,
                    request_id=request_id,
                )
                encoder_comps, decoder_comps = await asyncio.gather(
                    encoder_coro, decoder_coro)
        else:
            encoder_comps = await self._extract_prompt_components_async(
                inputs,
                request_id=request_id,
            )
            decoder_comps = None, None, None

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)

    async def _process_decoder_only_prompt_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        """Async version of :meth:`_process_decoder_only_prompt`."""
        components = await self._extract_prompt_components_async(
            inputs, request_id=request_id, lora_request=lora_request)
        return self._build_decoder_only_llm_inputs(
            components, prompt_adapter_request=prompt_adapter_request)

    async def process_model_inputs_async(
        self,
        inputs: PromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
        """Async version of :meth:`process_model_inputs`.

        Raises:
            ValueError: If an explicit encoder/decoder prompt is given to a
                decoder-only model.
        """
        model_inputs: Union[LLMInputs, EncoderDecoderLLMInputs]
        if not self.is_encoder_decoder_model():
            if is_explicit_encoder_decoder_prompt(inputs):
                raise ValueError("Cannot pass encoder-decoder prompt "
                                 "to decoder-only models")
            model_inputs = await self._process_decoder_only_prompt_async(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )
        else:
            # Encoder-decoder models map the prompt onto encoder & decoder.
            model_inputs = await self._process_encoder_decoder_prompt_async(
                inputs,
                request_id=request_id,
            )

        return self.input_processor(model_inputs)

    async def add_request_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        """Async version of :meth:`add_request`.

        Raises:
            ValueError: If a LoRA request is given while LoRA is disabled.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        arrival_time = time.time() if arrival_time is None else arrival_time

        processed_inputs = await self.process_model_inputs_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
        )

    async def check_health_async(self) -> None:
        """Run the tokenizer (if any) and model-executor health checks."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
604

605

606
class AsyncLLMEngine:
    """An asynchronous wrapper for :class:`LLMEngine`.

    This class wraps the :class:`LLMEngine` class to make it asynchronous.
    It uses asyncio to run a background loop that keeps processing incoming
    requests. The background loop is started by the generate (or encode)
    method when requests arrive. Those methods then yield the outputs from
    the :class:`LLMEngine` to the caller.

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
            engine runs in a separate (Ray actor) process from the async
            frontend. Deprecated: only allowed when
            `VLLM_ALLOW_ENGINE_USE_RAY=1` is set.
        log_requests: Whether to log the requests.
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
    """
    # Engine implementation instantiated by `_init_engine`. It is built
    # directly, or wrapped with `ray.remote` when `engine_use_ray` is set.
    # Presumably a class attribute so that subclasses can substitute their
    # own engine class — confirm before relying on it.
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

631
632
633
634
635
    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Create the async wrapper and the engine it drives.

        Args:
            worker_use_ray: Whether model workers use Ray.
            engine_use_ray: Whether to run the wrapped engine as a Ray actor.
                Deprecated; only accepted when ``VLLM_ALLOW_ENGINE_USE_RAY``
                is set.
            *args: Positional arguments forwarded to the engine class.
            log_requests: Whether to log request events.
            start_engine_loop: If True, the background loop is started
                automatically by the first request.
            **kwargs: Keyword arguments forwarded to the engine class.

        Raises:
            ValueError: If ``engine_use_ray`` is requested without
                ``VLLM_ALLOW_ENGINE_USE_RAY``.
        """
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests

        # Validate the deprecated engine-as-Ray-actor mode *before* building
        # the engine. Otherwise a rejected configuration would first pay for
        # constructing the engine (in Ray mode: launching an actor) and only
        # then raise.
        if self.engine_use_ray:
            print_warning_once(
                "DEPRECATED. `--engine-use-ray` is deprecated and will "
                "be removed in a future update. "
                "See https://github.com/vllm-project/vllm/issues/7045.")

            if envs.VLLM_ALLOW_ENGINE_USE_RAY:
                print_warning_once(
                    "VLLM_ALLOW_ENGINE_USE_RAY is set, force engine use Ray")
            else:
                raise ValueError("`--engine-use-ray` is deprecated. "
                                 "Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to "
                                 "force use it")

        self.engine = self._init_engine(*args, **kwargs)

        self.background_loop: Optional[asyncio.Future] = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Lazy initialized fields (set in start_background_loop so the
        # tracker is created on the running event loop).
        self._request_tracker: RequestTracker
    @classmethod
    def _get_executor_cls(
            cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
        """Select the async executor class for ``engine_config``.

        If ``distributed_executor_backend`` is itself a class, it is used
        directly after validation. Otherwise the choice depends on the device
        type and, where relevant, on the ``"ray"`` / ``"mp"`` backend name.
        A Ray cluster is initialized only for Ray-based executors.

        Raises:
            TypeError: If a custom executor class does not subclass
                :class:`ExecutorAsyncBase`.
            RuntimeError: If an unsupported distributed backend is requested
                on an XPU device.
        """
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
            if distributed_executor_backend.uses_ray:  # type: ignore
                initialize_ray_cluster(engine_config.parallel_config)
            executor_class = distributed_executor_backend
        elif engine_config.device_config.device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            executor_class = NeuronExecutorAsync
        elif engine_config.device_config.device_type == "tpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                executor_class = RayTPUExecutorAsync
            else:
                assert distributed_executor_backend is None
                from vllm.executor.tpu_executor import TPUExecutorAsync
                executor_class = TPUExecutorAsync
        elif engine_config.device_config.device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutorAsync
            executor_class = CPUExecutorAsync
        elif engine_config.device_config.device_type == "openvino":
            assert distributed_executor_backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            executor_class = OpenVINOExecutorAsync
        elif engine_config.device_config.device_type == "xpu":
            if distributed_executor_backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                executor_class = XPUExecutorAsync
            elif distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                executor_class = RayXPUExecutorAsync
            elif distributed_executor_backend == "mp":
                # The multiprocessing backend does not use Ray, so (like the
                # GPU "mp" branch below) no Ray cluster is initialized here.
                from vllm.executor.multiproc_xpu_executor import (
                    MultiprocessingXPUExecutorAsync)
                executor_class = MultiprocessingXPUExecutorAsync
            else:
                raise RuntimeError(
                    "Not supported distributed execution model on XPU device.")
        elif distributed_executor_backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            executor_class = RayGPUExecutorAsync
        elif distributed_executor_backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            executor_class = MultiprocessingGPUExecutorAsync
        else:
            from vllm.executor.gpu_executor import GPUExecutorAsync
            executor_class = GPUExecutorAsync
        return executor_class

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Build an :class:`AsyncLLMEngine` from :class:`AsyncEngineArgs`.

        The engine config is derived from ``engine_args``. Ray availability
        is asserted when ``engine_use_ray`` is set. The executor class is
        chosen via :meth:`_get_executor_cls`, and all of it is forwarded to
        the constructor.
        """
        config = engine_args.create_engine_config()

        if engine_args.engine_use_ray:
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()

        executor_cls = cls._get_executor_cls(config)

        return cls(
            executor_cls.uses_ray,
            engine_args.engine_use_ray,
            **config.to_dict(),
            executor_class=executor_cls,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )
    @property
    def is_running(self) -> bool:
        """True while the background loop exists and its task is not done."""
        task = self._background_loop_unshielded
        if self.background_loop is None or task is None:
            return False
        return not task.done()

    @property
    def is_stopped(self) -> bool:
        """True if the engine errored or its background loop task finished."""
        if self.errored:
            return True
        task = self._background_loop_unshielded
        return (self.background_loop is not None and task is not None
                and task.done())

    @property
    def errored(self) -> bool:
        """Whether a fatal exception has been recorded via set_errored."""
        recorded = self._errored_with
        return recorded is not None
    @property
    def limit_concurrency(self) -> Optional[int]:
        """Maximum number of concurrently running requests.

        This wrapper enforces no limit of its own and always reports ``None``.
        """
        return None
    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` as the reason the engine is considered errored."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Handle a background-loop failure.

        Marks the engine as errored and hands ``exc`` to the request tracker
        via ``propagate_exception``.
        """
        self.set_errored(exc)
        tracker = self._request_tracker
        tracker.propagate_exception(exc)
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer, optionally specialized for ``lora_request``.

        Locally this queries the engine's tokenizer group. In Ray mode the
        lookup is delegated to the engine actor.
        """
        if not self.engine_use_ray:
            tokenizer_group = self.engine.get_tokenizer_group()
            return await tokenizer_group.get_lora_tokenizer_async(lora_request)
        return await self.engine.get_tokenizer.remote(  # type: ignore
            lora_request)
    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # Created here rather than in __init__ so the tracker uses the event
        # loop that will actually run the engine.
        self._request_tracker = RequestTracker()

        loop = asyncio.get_event_loop()
        task = loop.create_task(self.run_engine_loop())
        self._background_loop_unshielded = task
        on_done = partial(_log_task_completion,
                          error_callback=self._error_callback)
        task.add_done_callback(on_done)
        # Callers get a shielded view so that cancelling what they hold does
        # not cancel the underlying loop task.
        self.background_loop = asyncio.shield(task)
    def shutdown_background_loop(self) -> None:
        """Shut down the background loop.

        Must be called during cleanup. It cancels the loop task and drops
        the references to it, so `self` and the resources held by the async
        LLM engine (e.g. the executors and their resources) can be
        garbage-collected.
        """
        task = self._background_loop_unshielded
        if task is not None:
            task.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None
    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
        """Instantiate the wrapped engine, in-process or as a Ray actor.

        Without ``engine_use_ray`` the engine class is constructed directly.
        Otherwise it is wrapped with ``ray.remote`` and launched as an actor:

        - When workers use Ray, the actor reserves no CPU (``num_cpus=0``).
        - Otherwise it requests ``num_gpus``. This is
          ``gpu_memory_utilization`` when tensor and pipeline parallel sizes
          are both 1, and 1 otherwise.
        """
        if not self.engine_use_ray:
            return self._engine_class(*args, **kwargs)

        if self.worker_use_ray:
            actor_options = {"num_cpus": 0}
        else:
            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
            # order of the arguments.
            cache_config = kwargs["cache_config"]
            parallel_config = kwargs["parallel_config"]
            single_device = (parallel_config.tensor_parallel_size == 1
                             and parallel_config.pipeline_parallel_size == 1)
            actor_options = {
                "num_gpus": (cache_config.gpu_memory_utilization
                             if single_device else 1)
            }
        remote_engine_cls = ray.remote(**actor_options)(self._engine_class)
        return remote_engine_cls.remote(*args, **kwargs)
    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Steps:

        1. Forward the new and aborted requests collected by the request
           tracker to the engine.
        2. Run one engine step for ``virtual_engine``. The Ray-actor path
           calls ``step`` without it.
        3. Hand every resulting output back to the tracker.

        Returns:
            True if any output produced by this step is not finished, i.e.
            there are in-progress requests. False otherwise, including when
            the step produced no outputs.
        """
        pending_adds, pending_aborts = (
            self._request_tracker.get_new_and_aborted_requests())

        for request_kwargs in pending_adds:
            # Add the request into the vLLM engine's waiting queue.
            # TODO: Maybe add add_request_batch to reduce Ray overhead
            try:
                if not self.engine_use_ray:
                    await self.engine.add_request_async(**request_kwargs)
                else:
                    await self.engine.add_request.remote(  # type: ignore
                        **request_kwargs)
            except ValueError as err:
                # TODO: use a vLLM specific error for failed validation
                # A validation failure is reported for that request id only;
                # it does not stop the processing of the other requests.
                self._request_tracker.process_exception(
                    request_kwargs["request_id"],
                    err,
                    verbose=self.log_requests,
                )

        if pending_aborts:
            await self._engine_abort(pending_aborts)

        if not self.engine_use_ray:
            step_outputs = await self.engine.step_async(virtual_engine)
        else:
            step_outputs = await self.engine.step.remote()  # type: ignore

        # Put the outputs into the corresponding streams.
        has_unfinished = False
        for output in step_outputs:
            self._request_tracker.process_request_output(
                output, verbose=self.log_requests)
            if not has_unfinished and not output.finished:
                has_unfinished = True

        return has_unfinished
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Forward a batch of abort requests to the (possibly remote) engine."""
        if not self.engine_use_ray:
            self.engine.abort_request(request_ids)
            return
        await self.engine.abort_request.remote(request_ids)  # type: ignore

    async def run_engine_loop(self):
        """Background loop that keeps the engine stepping.

        Keeps one :meth:`engine_step` task per virtual engine: the
        pipeline-parallel size, or a single one when the engine is a Ray
        actor. When no virtual engine has work, remote workers are asked to
        stop their execution loop. The loop then blocks until the request
        tracker reports new requests.

        Waiting for step completion is bounded by
        ``ENGINE_ITERATION_TIMEOUT_S``. On timeout the engine is marked as
        errored and the timeout is re-raised. The loop never returns on its
        own; it ends through cancellation or an exception from a step task.
        """
        if self.engine_use_ray:
            num_virtual_engines = 1
        else:
            parallel_config = self.engine.parallel_config
            num_virtual_engines = parallel_config.pipeline_parallel_size
        busy = [False] * num_virtual_engines

        while True:
            if not any(busy):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                if not self.engine_use_ray:
                    await self.engine.stop_remote_worker_execution_loop_async()
                else:
                    await (self.engine.stop_remote_worker_execution_loop.
                           remote()  # type: ignore
                           )
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")
                step_tasks = [
                    asyncio.create_task(self.engine_step(ve))
                    for ve in range(num_virtual_engines)
                ]
                busy = [True] * num_virtual_engines

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    completed, _ = await asyncio.wait(
                        step_tasks, return_when=asyncio.FIRST_COMPLETED)
                    # Yield to the event loop once per virtual engine before
                    # handling the completed steps.
                    for _ in range(num_virtual_engines):
                        await asyncio.sleep(0)
                for step_task in completed:
                    step_has_work = step_task.result()
                    ve = step_tasks.index(step_task)
                    if not self.engine_use_ray:
                        engine_has_work = (
                            self.engine.
                            has_unfinished_requests_for_virtual_engine(ve))
                    else:
                        engine_has_work = await (
                            self.engine.
                            has_unfinished_requests_for_virtual_engine.
                            remote(ve))  # type: ignore
                    if step_has_work or engine_has_work:
                        # Re-arm this virtual engine with a fresh step task.
                        step_tasks[ve] = asyncio.create_task(
                            self.engine_step(ve))
                        busy[ve] = True
                    else:
                        busy[ve] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise

            await asyncio.sleep(0)
    # This method does not need to be async, but kept that way
    # for backwards compatibility.
Antoni Baum's avatar
Antoni Baum committed
969
970
971
    async def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Register a request and return the async generator of its outputs.

        This method does not need to be async, but is kept that way for
        backwards compatibility: awaiting it yields the stream's generator.

        The background loop is started on demand when ``start_engine_loop``
        is enabled. ``arrival_time`` defaults to the current wall-clock time
        only when it is ``None``; an explicit ``0.0`` is preserved.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                ``start_engine_loop`` is disabled. start_background_loop may
                also raise it when the engine has already errored.
        """
        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        # `arrival_time or time.time()` would discard a legitimate 0.0.
        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            inputs=inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request)

        return stream.generator()
    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is an async generator. It
        adds the request into the waiting queue of the LLMEngine and streams
        the outputs from the LLMEngine to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                cannot be started.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming HTTP request (see api_server.py)
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in await self.add_request(
                request_id,
                inputs,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
        ):
            yield LLMEngine.validate_output(output, RequestOutput)

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model.

        Generate outputs for a request. This method is an async generator. It
        adds the request into the waiting queue of the LLMEngine and streams
        the outputs from the LLMEngine to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.

        Yields:
            The output `EmbeddingRequestOutput` objects from the LLMEngine
            for the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                cannot be started.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming HTTP request (see api_server.py)
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in await self.add_request(
                request_id,
                inputs,
                pooling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
        ):
            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
    async def abort(self, request_id: str) -> None:
        """Abort a submitted request.

        If the request is finished or not found, this method is a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if self.is_running:
            return self._abort(request_id)
        raise AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")
    def _abort(self, request_id: str) -> None:
        """Abort a request without checking whether the loop is running.

        Delegates to the request tracker with ``asyncio.CancelledError`` as
        the exception for the aborted request. If the request is finished or
        not found, this is a no-op.

        Args:
            request_id: The unique id of the request.
        """
        tracker = self._request_tracker
        tracker.abort_request(request_id,
                              exception=asyncio.CancelledError,
                              verbose=self.log_requests)
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        if not self.engine_use_ray:
            return self.engine.get_model_config()
        return await self.engine.get_model_config.remote()  # type: ignore
    async def get_parallel_config(self) -> ParallelConfig:
        """Get the parallel configuration of the vLLM engine."""
        if not self.engine_use_ray:
            return self.engine.get_parallel_config()
        return await self.engine.get_parallel_config.remote(  # type: ignore
        )
    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        if not self.engine_use_ray:
            return self.engine.get_decoding_config()
        return await self.engine.get_decoding_config.remote(  # type: ignore
        )
    async def get_scheduler_config(self) -> SchedulerConfig:
        """Get the scheduling configuration of the vLLM engine."""
        if not self.engine_use_ray:
            return self.engine.get_scheduler_config()
        return await self.engine.get_scheduler_config.remote(  # type: ignore
        )

    async def get_lora_config(self) -> LoRAConfig:
        """Get the lora configuration of the vLLM engine."""
        if not self.engine_use_ray:
            return self.engine.get_lora_config()
        return await self.engine.get_lora_config.remote(  # type: ignore
        )
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the engine to log its statistics.

        Args:
            scheduler_outputs: Optional scheduler outputs, forwarded to the
                engine's ``do_log_stats``.
            model_output: Optional sampler outputs, forwarded to the
                engine's ``do_log_stats``.
        """
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote(  # type: ignore
                scheduler_outputs, model_output)
        else:
            # Forward the arguments here too, matching the Ray path; they
            # used to be silently dropped for in-process engines.
            self.engine.do_log_stats(scheduler_outputs, model_output)
    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy.

        Raises:
            AsyncEngineDeadError: If the background loop is stopped.
            RuntimeError: If the Ray engine actor has died (Ray mode only).
        """
        started_at = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        if not self.engine_use_ray:
            await self.engine.check_health_async()
        else:
            try:
                await self.engine.check_health.remote()  # type: ignore
            except ray.exceptions.RayActorError as actor_err:
                raise RuntimeError("Engine is dead.") from actor_err
        elapsed = time.perf_counter() - started_at
        logger.debug("Health check took %fs", elapsed)

    async def is_tracing_enabled(self) -> bool:
        """Report whether the wrapped engine has tracing enabled."""
        if not self.engine_use_ray:
            return self.engine.is_tracing_enabled()
        return await self.engine.is_tracing_enabled.remote(  # type: ignore
        )

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger under ``logger_name``.

        In Ray mode this blocks on ``ray.get`` until the actor has applied
        the change.
        """
        if not self.engine_use_ray:
            self.engine.add_logger(logger_name=logger_name, logger=logger)
            return
        ray.get(
            self.engine.add_logger.remote(  # type: ignore
                logger_name=logger_name, logger=logger))

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the stat logger named ``logger_name``.

        In Ray mode this blocks on ``ray.get`` until the actor has applied
        the change.
        """
        if not self.engine_use_ray:
            self.engine.remove_logger(logger_name=logger_name)
            return
        ray.get(
            self.engine.remove_logger.remote(  # type: ignore
                logger_name=logger_name))

    async def start_profile(self) -> None:
        """Forward a ``start_profile`` call to the executor's workers.

        NOTE(review): this reaches into ``self.engine.model_executor``, which
        an engine launched as a Ray actor (``engine_use_ray``) does not
        expose. It also relies on the executor providing ``_run_workers``.
        """
        executor = self.engine.model_executor
        executor._run_workers("start_profile")

    async def stop_profile(self) -> None:
        """Forward a ``stop_profile`` call to the executor's workers.

        NOTE(review): this reaches into ``self.engine.model_executor``, which
        an engine launched as a Ray actor (``engine_use_ray``) does not
        expose. It also relies on the executor providing ``_run_workers``.
        """
        executor = self.engine.model_executor
        executor._run_workers("stop_profile")