async_llm_engine.py 53 KB
Newer Older
1
2
import asyncio
import time
3
import weakref
Antoni Baum's avatar
Antoni Baum committed
4
from functools import partial
5
6
from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
                    List, Mapping, Optional, Set, Tuple, Type, Union, overload)
7
from weakref import ReferenceType
8

9
import vllm.envs as envs
10
11
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
12
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
13
from vllm.engine.arg_utils import AsyncEngineArgs
14
from vllm.engine.async_timeout import asyncio_timeout
15
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
16
from vllm.engine.metrics_types import StatLoggerBase
17
from vllm.entrypoints.llm import BeamSearchSequence
18
from vllm.executor.executor_base import ExecutorAsyncBase
19
from vllm.executor.gpu_executor import GPUExecutorAsync
20
from vllm.executor.ray_utils import initialize_ray_cluster
21
from vllm.inputs import PromptType, TokensPrompt
Woosuk Kwon's avatar
Woosuk Kwon committed
22
from vllm.logger import init_logger
23
from vllm.lora.request import LoRARequest
24
25
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
26
from vllm.model_executor.layers.sampler import SamplerOutput
27
28
from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
                          RequestOutput)
29
from vllm.pooling_params import PoolingParams
30
from vllm.prompt_adapter.request import PromptAdapterRequest
31
from vllm.sampling_params import BeamSearchParams, SamplingParams
32
from vllm.sequence import ExecuteModelRequest
33
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
34
from vllm.usage.usage_lib import UsageContext
35
from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
36
                        get_beam_search_score, random_uuid, weak_bind)
37
38

# Module-level logger for the async engine.
logger = init_logger(__name__)

# Engine iteration timeout, read once at import time from
# vllm.envs.VLLM_ENGINE_ITERATION_TIMEOUT_S (the _S suffix suggests seconds).
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
40

Antoni Baum's avatar
Antoni Baum committed
41

42
43
44
45
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine background task finishes unexpectedly."""


46
47
48
49
50
51
52
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Inspect the finished `engine.run_engine_loop()` task.

    This function is only intended for that task. It runs a `while True`
    loop that can only exit if there is an exception, so any outcome other
    than cancellation is treated as a fatal engine failure.

    Args:
        task: The finished background task.
        error_callback: Called with the exception that ended the task. A
            normal return is reported as an ``AssertionError``.

    Raises:
        AsyncEngineDeadError: If the task ended for any reason other than
            cancellation. It is chained to the original exception.
    """
    try:
        return_value = task.result()
        # A normal return is itself a bug. Raising here deliberately routes
        # it through the generic `except Exception` branch below.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {return_value}")
    except asyncio.exceptions.CancelledError:
        # We assume that if the task is cancelled, we are gracefully shutting
        # down. This should only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from e
72
73


74
75
76
# Unique sentinel object. AsyncStream enqueues it to end iteration normally.
STOP_ITERATION = Exception()


Antoni Baum's avatar
Antoni Baum committed
77
class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator.

    Items are buffered in an unbounded ``asyncio.Queue``. The stream ends when
    :meth:`finish` enqueues either an exception (instance or class), which is
    re-raised to the consumer, or the ``STOP_ITERATION`` sentinel, which ends
    iteration normally.
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        # Invoked with ``request_id`` if the consumer closes the generator
        # before the stream is exhausted.
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Enqueue ``item`` for the consumer.

        Items put after :meth:`finish` has been called are silently dropped.
        """
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Terminate the stream. Calls after the first are no-ops.

        Args:
            exception: If it is an exception instance or class, the consumer
                raises it. Otherwise (e.g. ``None``) iteration ends normally.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has been called."""
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Yield queued outputs until the stream is finished.

        Raises:
            BaseException: Whatever exception was passed to :meth:`finish`.
            asyncio.CancelledError: If the consumer closes the generator early.
                The cancel callback is invoked first.
        """
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    # STOP_ITERATION is a unique sentinel object, so compare
                    # by identity rather than equality.
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """True if ``value`` is an exception instance or exception class."""
        return isinstance(value, BaseException) or (
            isinstance(value, type) and issubclass(value, BaseException))

Antoni Baum's avatar
Antoni Baum committed
126

127
128
129
130
131
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    New requests are queued by :meth:`add_request` and aborts by
    :meth:`abort_request`. Both queues are drained together by
    :meth:`get_new_and_aborted_requests`. Streams for requests that have been
    handed over are kept in ``_request_streams``, keyed by request id.
    """

    def __init__(self) -> None:
        # Streams of requests already returned by get_new_and_aborted_requests.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Request ids whose abort has not yet been collected.
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, engine add_request kwargs) pairs not yet collected.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set by add_request(); awaited by wait_for_new_requests().
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None)."""
        if request_id is not None:
            self.abort_request(request_id, exception=exc)
        else:
            # NB: tuple() used here because self.abort_request pops the stream
            # out of self._request_streams, so we can't iterate on it directly
            for rid in tuple(self._request_streams.keys()):
                self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        The output is forwarded to the request's stream. If the output is
        final, the stream is also removed from tracking and finished.
        """
        request_id = request_output.request_id
        finished = request_output.finished

        if finished:
            stream = self._request_streams.pop(request_id, None)
        else:
            stream = self._request_streams.get(request_id)
        # Guard against a KeyError which can occur if the request was aborted
        # while the output was generated
        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if verbose and finished:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Returns:
            The stream the caller iterates to receive outputs.

        Raises:
            KeyError: If ``request_id`` is already being tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        # Closing the stream's generator early aborts the request.
        abort_request = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, abort_request)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Abort a request during next background loop iteration.

        The id is always queued for abort. If the request already has a
        tracked stream, that stream is removed and finished with
        ``exception`` (or normally if ``exception`` is None).
        """
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Drain both pending queues.

        Returns:
            A ``(new_requests, aborted_ids)`` pair. ``new_requests`` holds the
            engine ``add_request`` kwargs of newly tracked requests.
            ``aborted_ids`` holds request ids to abort. A request that was
            both added and aborted since the last call is finished with
            ``asyncio.CancelledError`` and appears in neither.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._aborted_requests.empty():
            request_id = self._aborted_requests.get_nowait()
            finished_requests.add(request_id)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            request_id = stream.request_id
            if request_id in finished_requests:
                # The request has already been aborted.
                stream.finish(asyncio.CancelledError)
                finished_requests.discard(request_id)
            else:
                self._request_streams[request_id] = stream
                new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Wait until at least one new request is queued, then reset the
        event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Whether any added requests have not yet been collected."""
        return not self._new_requests.empty()
259

Antoni Baum's avatar
Antoni Baum committed
260
261
262
263

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to :class:`LLMEngine`."""
        super().__init__(*args, **kwargs)

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):

            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # Clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            if outputs and allow_async_output_proc:
                assert len(
                    outputs
                ) == 1, "Async postprocessor expects only a single output set"
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)

        else:
            # Multi-step case: steps remain for this batch, so return early
            # without the drain check below.
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    @overload  # DEPRECATED
    async def add_request_async(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @overload
    async def add_request_async(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request_async(
            self,
            request_id: str,
            prompt: Optional[PromptType] = None,
            params: Optional[Union[SamplingParams, PoolingParams]] = None,
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
            *,
            inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        """Async version of :meth:`add_request`.

        Raises:
            ValueError: If a LoRA request is given while LoRA is disabled, or
                a non-zero priority is given while the scheduler policy is not
                "priority".
        """
        # The deprecated ``inputs`` keyword takes precedence over ``prompt``.
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")
        if arrival_time is None:
            arrival_time = time.time()

        preprocessed_inputs = await self.input_preprocessor.preprocess_async(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        if isinstance(params, SamplingParams) and \
            params.guided_decoding is not None:
            # Guided decoding has an async implementation for building logits
            # processors in a separate threadpool.
            # We want to invoke that here instead of using the blocking
            # implementation in the LLMEngine
            params = await build_guided_decoding_logits_processor_async(
                sampling_params=params,
                tokenizer=self.get_tokenizer(lora_request),
                default_guided_backend=self.decoding_config.
                guided_decoding_backend)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    async def check_health_async(self) -> None:
        """Check the tokenizer's health (if one is set), then the model
        executor's."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
512

513

514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
async def build_guided_decoding_logits_processor_async(
        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
        default_guided_backend: str) -> SamplingParams:
    """Build a guided-decoding logits processor for ``sampling_params``.

    Only the ``guided_decoding`` field is consulted. If it is ``None``, the
    params are returned unchanged. Otherwise:

    * ``guided_decoding.backend`` is set to ``default_guided_backend`` when it
      is unset (falsy). This mutates the ``guided_decoding`` object itself.
    * The processor built from it, if any, is appended to
      ``sampling_params.logits_processors``. That list is created if it is
      ``None``.
    * ``sampling_params.guided_decoding`` is reset to ``None``. Presumably
      this stops a later synchronous path from building it again; confirm
      against LLMEngine.

    Args:
        sampling_params: Params to update in place.
        tokenizer: Passed through to the processor factory.
        default_guided_backend: Backend used when none is set on the params.

    Returns:
        The same ``sampling_params`` object, modified in place.
    """
    if (guided_decoding := sampling_params.guided_decoding) is None:
        return sampling_params

    logger.debug("Building guided decoding logits processor. "
                 "Params: %s", guided_decoding)

    guided_decoding.backend = guided_decoding.backend or default_guided_backend

    processor = await get_guided_decoding_logits_processor(
        guided_params=guided_decoding, tokenizer=tokenizer)

    if processor:
        if sampling_params.logits_processors is None:
            sampling_params.logits_processors = []
        sampling_params.logits_processors.append(processor)

    # Unset guided decoding params after constructing the lp from them
    sampling_params.guided_decoding = None

    return sampling_params


544
class AsyncLLMEngine:
545
    """An asynchronous wrapper for :class:`LLMEngine`.
546

547
548
549
550
551
    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.
552
553

    Args:
554
        log_requests: Whether to log the requests.
555
556
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
557
558
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
559
    """
560

Antoni Baum's avatar
Antoni Baum committed
561
562
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

563
564
565
    def __init__(self,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Build the wrapped engine and initialize background-loop state.

        Args:
            *args: Positional arguments forwarded to ``self._engine_class``.
            log_requests: Whether to log the requests.
            start_engine_loop: If True, the background task that runs the
                engine is started on demand when a request is added.
            **kwargs: Keyword arguments forwarded to ``self._engine_class``.
        """
        self.log_requests = log_requests
        engine = self._engine_class(*args, **kwargs)
        self.engine = engine

        # With async output processing, the wrapped engine hands request
        # outputs to us through a callback. This keeps appends to the asyncio
        # queues from being delayed, especially for multi-step.
        # NOTE(review): weak_bind presumably binds through a weak reference
        # so the engine does not keep this wrapper alive; confirm in
        # vllm.utils.
        use_callback = engine.model_config.use_async_output_proc
        self.use_process_request_outputs_callback = use_callback
        if use_callback:
            engine.process_request_outputs_callback = weak_bind(
                self.process_request_outputs)

        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Shielded view of the engine-loop task; None until started.
        self.background_loop: Optional[asyncio.Future] = None
        # The unshielded task itself. A reference must be kept so the task
        # is not garbage collected while it runs.
        self._background_loop_unshielded: Optional[asyncio.Task] = None

        # Lazily created in start_background_loop() so that it uses the
        # right event loop.
        self._request_tracker: RequestTracker

    def __del__(self):
        """Wake the background loop on finalization so it can exit cleanly.

        ``run_engine_loop`` holds only a weak reference to the engine while
        it waits on the request tracker's event. Setting that event lets it
        wake, observe that the engine is gone, and return instead of waiting
        forever.
        """
        # The tracker lives in ``_request_tracker`` and is only assigned once
        # start_background_loop() has run, hence getattr with a default.
        rt = getattr(self, "_request_tracker", None)
        if rt is not None:
            try:
                rt.new_requests_event.set()
            except RuntimeError:
                # Best effort during finalization: waking waiters can fail if
                # their event loop is already closed, in which case there is
                # nothing left to wake.
                pass

    @classmethod
    def _get_executor_cls(
            cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
        """Select the async executor class for ``engine_config``.

        If ``distributed_executor_backend`` is a class, it is used directly
        and must subclass ``ExecutorAsyncBase``. Otherwise the device type
        decides (neuron, tpu, cpu, openvino, xpu). For any other device, the
        backend name picks the GPU executor: "ray", "mp", or single-process
        by default.
        """
        backend = engine_config.parallel_config.distributed_executor_backend

        if isinstance(backend, type):
            if not issubclass(backend, ExecutorAsyncBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorAsyncBase. Got {backend}.")
            return backend

        device_type = engine_config.device_config.device_type

        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            return NeuronExecutorAsync

        if device_type == "tpu":
            if backend == "ray":
                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                return RayTPUExecutorAsync
            assert backend is None
            from vllm.executor.tpu_executor import TPUExecutorAsync
            return TPUExecutorAsync

        if device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutorAsync
            return CPUExecutorAsync

        if device_type == "openvino":
            assert backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            return OpenVINOExecutorAsync

        if device_type == "xpu":
            if backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                return XPUExecutorAsync
            if backend == "ray":
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                return RayXPUExecutorAsync
            if backend == "mp":
                from vllm.executor.multiproc_xpu_executor import (
                    MultiprocessingXPUExecutorAsync)
                return MultiprocessingXPUExecutorAsync
            raise RuntimeError(
                "Not supported distributed execution model on XPU device.")

        # Remaining devices: GPU executors chosen by backend name.
        if backend == "ray":
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            return RayGPUExecutorAsync
        if backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            return MultiprocessingGPUExecutorAsync

        from vllm.executor.gpu_executor import GPUExecutorAsync
        return GPUExecutorAsync

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        engine_config: Optional[EngineConfig] = None,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        The engine config is derived from ``engine_args`` unless one is
        passed explicitly. A Ray cluster is initialized first when the
        selected executor class uses Ray.
        """
        config = (engine_config if engine_config is not None else
                  engine_args.create_engine_config())

        executor_class = cls._get_executor_cls(config)
        if executor_class.uses_ray:
            initialize_ray_cluster(config.parallel_config)

        return cls(
            **config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    @property
    def is_running(self) -> bool:
        """True while the engine-loop task exists and has not completed."""
        if self.background_loop is None:
            return False
        task = self._background_loop_unshielded
        return task is not None and not task.done()

    @property
    def is_stopped(self) -> bool:
        """True if the engine errored, or its started loop task completed."""
        if self.errored:
            return True
        task = self._background_loop_unshielded
        return (self.background_loop is not None and task is not None
                and task.done())

    @property
    def errored(self) -> bool:
        """Whether a fatal error has been recorded via ``set_errored``."""
        recorded = self._errored_with
        return recorded is not None

    @property
    def dead_error(self) -> BaseException:
        """A new ``AsyncEngineDeadError`` saying the loop is not running."""
        message = ("Background loop is not running. If it was running, "
                   "inspect the output to find the stacktrace of the "
                   "error that caused the background loop to stop "
                   "(AsyncEngineDeadError).")
        return AsyncEngineDeadError(message)

    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` as the error that killed the engine."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine errored, then forward ``exc`` to the tracker.

        Registered (via ``_log_task_completion``) on the engine-loop task.
        """
        self.set_errored(exc)
        tracker = self._request_tracker
        tracker.propagate_exception(exc)

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer for ``lora_request`` (or the base one)."""
        tokenizer_group = self.engine.get_tokenizer_group()
        return await tokenizer_group.get_lora_tokenizer_async(lora_request)

    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the background loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        loop = asyncio.get_event_loop()
        # The loop task receives only a weak reference to this engine.
        task = loop.create_task(self.run_engine_loop(weakref.ref(self)))
        self._background_loop_unshielded = task
        task.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        self.background_loop = asyncio.shield(task)

    def shutdown_background_loop(self) -> None:
        """
        Shut down the background loop.

        Call this during cleanup. It cancels the engine-loop task and drops
        the references to it, so that `self` and the resources held by the
        async LLM engine (e.g., the executors and their resources) can be
        garbage collected.
        """
        task = self._background_loop_unshielded
        if task is not None:
            task.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Newly submitted requests are handed to the wrapped engine first. A
        ``ValueError`` raised while adding one is routed to that request's
        stream instead of failing the step. Pending aborts are applied next,
        then one engine step runs on ``virtual_engine``.

        Returns:
            True if at least one request output of this step is not finished,
            i.e. there are in-progress requests.
        """
        incoming, to_abort = (
            self._request_tracker.get_new_and_aborted_requests())

        for request_kwargs in incoming:
            try:
                # Enqueue into the vLLM engine's waiting queue.
                await self.engine.add_request_async(**request_kwargs)
            except ValueError as err:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    request_kwargs["request_id"],
                    err,
                    verbose=self.log_requests,
                )

        if to_abort:
            await self._engine_abort(to_abort)

        outputs = await self.engine.step_async(virtual_engine)

        if self.use_process_request_outputs_callback:
            # The outputs were already routed to their streams by the
            # callback invoked inside LLMEngine's _process_model_outputs;
            # only the completion status is needed here.
            all_done = all(out.finished for out in outputs)
        else:
            all_done = self.process_request_outputs(outputs)

        return not all_done

    def process_request_outputs(self, request_outputs) -> bool:
        """Route each output to its request stream via the request tracker.

        Returns:
            True if every output in ``request_outputs`` is finished (also
            True for an empty batch).
        """
        everything_finished = True
        for out in request_outputs:
            self._request_tracker.process_request_output(
                out, verbose=self.log_requests)
            if not out.finished:
                everything_finished = False
        return everything_finished

    async def _engine_abort(self, request_ids: Iterable[str]):
        """Forward the aborts for ``request_ids`` to the wrapped engine."""
        wrapped = self.engine
        wrapped.abort_request(request_ids)

    @staticmethod
    async def run_engine_loop(engine_ref: ReferenceType):
        """Drive engine steps until the engine is collected or fails.

        Only a weak reference to the engine is held while idle, so this
        long-running task does not prevent the engine from being garbage
        collected. One ``engine_step`` task is kept per virtual engine
        (pipeline stage). A step is rescheduled as long as it reports work
        in progress or its virtual engine still has unfinished requests.

        Raises:
            asyncio.TimeoutError: If an iteration exceeds
                ENGINE_ITERATION_TIMEOUT_S; the engine is marked errored
                first.
        """
        llm: Optional["AsyncLLMEngine"] = engine_ref()
        if not llm:
            return

        num_virtual_engines = \
                llm.engine.parallel_config.pipeline_parallel_size
        busy = [False] * num_virtual_engines

        while True:
            if not any(busy):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                await llm.engine.stop_remote_worker_execution_loop_async()
                tracker = llm._request_tracker
                # Drop the strong reference so the engine can be garbage
                # collected while we wait for new requests.
                del llm
                await asyncio.sleep(0)
                if engine_ref() is None:
                    return
                await tracker.wait_for_new_requests()
                llm = engine_ref()
                if not llm:
                    return
                logger.debug("Got new requests!")
                step_tasks = [
                    asyncio.create_task(llm.engine_step(ve))
                    for ve in range(num_virtual_engines)
                ]
                busy = [True] * num_virtual_engines

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        step_tasks, return_when=asyncio.FIRST_COMPLETED)
                    for _ in range(num_virtual_engines):
                        await asyncio.sleep(0)
                for finished_task in done:
                    more_work = finished_task.result()
                    ve = step_tasks.index(finished_task)
                    has_unfinished = (
                        llm.engine.has_unfinished_requests_for_virtual_engine(
                            ve))
                    if more_work or has_unfinished:
                        step_tasks[ve] = asyncio.create_task(
                            llm.engine_step(ve))
                        busy[ve] = True
                    else:
                        busy[ve] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                llm.set_errored(exc)
                raise

            await asyncio.sleep(0)

    # This method does not need to be async, but kept that way
    # for backwards compatibility.
    @overload  # DEPRECATED
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, EmbeddingRequestOutput], None]]:
        ...

    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, EmbeddingRequestOutput], None]]:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request(
        self,
        request_id: str,
        prompt: Optional[PromptType] = None,
        params: Optional[Union[SamplingParams, PoolingParams]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Register a request with the tracker and return its output stream.

        If the background loop is not running, it is started when
        ``start_engine_loop`` is set; otherwise the engine is considered dead.

        Args:
            request_id: The unique id of the request.
            prompt: The prompt (``inputs`` is a deprecated alias).
            params: Sampling or pooling parameters.
            arrival_time: Arrival timestamp; the current time when None.
            lora_request: LoRA request to use, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt adapter request to use, if any.
            priority: The priority of the request. Only applicable with
                priority scheduling.

        Returns:
            An async generator yielding the request's outputs.

        Raises:
            AsyncEngineDeadError: If the loop is not running and may not be
                started.
            ValueError: If a non-zero priority is given but the scheduler
                policy is not "priority".
        """
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise self.dead_error

        if (priority != 0
                and self.engine.scheduler_config.policy != "priority"):
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        # Only substitute "now" when no timestamp was given; 0.0 is a valid
        # (falsy) float timestamp and must be kept.
        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            prompt=prompt,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

        return stream.generator()

    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        This coroutine adds the request to the waiting queue of the
        LLMEngine and streams the LLMEngine's outputs back to the caller.

        Args:
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use
                for generation, if any.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": 0,
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(request_id)
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        request_stream = await self.add_request(
            request_id,
            prompt,
            sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
        async for raw_output in request_stream:
            yield LLMEngine.validate_output(raw_output, RequestOutput)

    async def beam_search(
        self,
        prompt: Union[PromptType, List[int]],
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search for ``prompt`` and yield a single final output.

        Each step issues one single-token ``generate`` sub-request per live
        beam, asking for ``2 * beam_width`` logprobs. Each beam is extended
        with the returned candidate tokens, and the ``beam_width`` best beams
        are kept, ranked by ``get_beam_search_score``. A candidate ending in
        the tokenizer's EOS token moves to the completed set unless
        ``params.ignore_eos`` is set.

        Args:
            prompt: Either a list of prompt token ids (used as-is) or a prompt
                passed to ``tokenizer.encode``.
            request_id: The caller's request id, reported on the final
                output. Sub-requests use separate, internally generated ids.
            params: Beam width, max tokens, EOS handling, temperature and
                length penalty.

        Yields:
            One finished ``RequestOutput`` with up to ``beam_width``
            completions, best first. Each completion's ``text`` and
            ``token_ids`` cover only the generated tokens, not the prompt.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty

        tokenizer = await self.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        tokenized_prompt = prompt if isinstance(
            prompt, list) else tokenizer.encode(prompt)
        tokenized_length = len(tokenized_prompt)

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         eos_token_id, length_penalty)

        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        all_beams = [
            BeamSearchSequence(tokens=tokenized_prompt, cum_logprob=0)
        ]
        completed = []

        for _ in range(max_tokens):
            if not all_beams:
                # Every candidate ended in EOS; nothing is left to extend.
                break

            # Sub-request ids derive from a fresh per-step id. It is kept
            # separate from ``request_id`` so the caller's id is the one
            # reported on the final output.
            step_request_id = f"beam_search-{random_uuid()}"
            tasks = []
            for i, beam in enumerate(all_beams):
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.generate(
                            TokensPrompt(prompt_token_ids=beam.tokens),
                            beam_search_params, f"{step_request_id}-{i}")))
                tasks.append(task)

            # Only the first output of each sub-request is used.
            step_outputs = [x[0] for x in await asyncio.gather(*tasks)]
            logger.debug("Beam search step outputs: %s", step_outputs)

            new_beams = []
            for current_beam, result in zip(all_beams, step_outputs):
                step_logprobs = result.outputs[0].logprobs
                if step_logprobs is None:
                    continue
                for token_id, logprob_obj in step_logprobs[0].items():
                    new_beam = BeamSearchSequence(
                        tokens=current_beam.tokens + [token_id],
                        cum_logprob=current_beam.cum_logprob +
                        logprob_obj.logprob)

                    if token_id == eos_token_id and not ignore_eos:
                        completed.append(new_beam)
                    else:
                        new_beams.append(new_beam)

            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
            all_beams = sorted_beams[:beam_width]

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            beam.text = tokenizer.decode(beam.tokens[tokenized_length:])

        beam_search_output = RequestOutput(
            request_id=request_id,
            prompt=prompt,
            outputs=[
                CompletionOutput(
                    text=beam.text,
                    cumulative_logprob=beam.cum_logprob,
                    # Generated tokens only, consistent with ``text``.
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    # Per-token logprobs are not tracked here. The field holds
                    # per-token logprob dicts (see how sub-request outputs
                    # are read above), not a float.
                    logprobs=None,
                ) for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=tokenized_prompt,
            prompt_logprobs=None)

        yield LLMEngine.validate_output(beam_search_output, RequestOutput)

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model.

        This coroutine adds the request to the waiting queue of the
        LLMEngine and streams the LLMEngine's outputs back to the caller.

        Args:
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Yields:
            The output `EmbeddingRequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": 0,
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(request_id)
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        request_stream = await self.add_request(
            request_id,
            prompt,
            pooling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=priority,
        )
        async for raw_output in request_stream:
            yield LLMEngine.validate_output(raw_output, EmbeddingRequestOutput)

    async def abort(self, request_id: str) -> None:
        """Abort a previously submitted request.

        If the request is finished or not found, this method will be a
        no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if self.is_running:
            # Happy path: delegate the actual abort to the sync helper.
            return self._abort(request_id)
        raise AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")
1237

Antoni Baum's avatar
Antoni Baum committed
1238
    def _abort(self, request_id: str) -> None:
1239
1240
1241
1242
1243
1244
1245
1246
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
1247
        self._request_tracker.abort_request(request_id,
1248
                                            exception=asyncio.CancelledError,
1249
                                            verbose=self.log_requests)
1250

1251
1252
    async def get_model_config(self) -> ModelConfig:
        """Return the model configuration held by the wrapped engine."""
        wrapped = self.engine
        return wrapped.get_model_config()
1254

1255
1256
    async def get_parallel_config(self) -> ParallelConfig:
        """Return the parallel configuration held by the wrapped engine."""
        wrapped = self.engine
        return wrapped.get_parallel_config()
1258

1259
1260
    async def get_decoding_config(self) -> DecodingConfig:
        """Return the decoding configuration held by the wrapped engine."""
        wrapped = self.engine
        return wrapped.get_decoding_config()
1262

1263
1264
    async def get_scheduler_config(self) -> SchedulerConfig:
        """Return the scheduler configuration held by the wrapped engine."""
        wrapped = self.engine
        return wrapped.get_scheduler_config()
1266
1267
1268

    async def get_lora_config(self) -> LoRAConfig:
        """Return the LoRA configuration held by the wrapped engine."""
        wrapped = self.engine
        return wrapped.get_lora_config()
1270

1271
1272
1273
1274
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Log engine statistics through the wrapped engine.

        Args:
            scheduler_outputs: Optional scheduler outputs for the iteration
                being logged; forwarded to the wrapped engine.
            model_output: Optional sampler outputs for the iteration being
                logged; forwarded to the wrapped engine.
        """
        # Previously both arguments were accepted but silently dropped.
        # Forward them so callers that supply iteration data get it logged.
        # NOTE(review): relies on LLMEngine.do_log_stats taking
        # (scheduler_outputs, model_output, ...) positionally with None
        # defaults, so a bare ``await do_log_stats()`` behaves as before.
        self.engine.do_log_stats(scheduler_outputs, model_output)
1276

1277
    async def check_health(self) -> None:
        """Raise an error if the engine is unhealthy.

        Raises:
            AsyncEngineDeadError: If the background loop has been stopped.
            Any exception raised by ``self.engine.check_health_async()``
            propagates unchanged.
        """
        started_at = time.perf_counter()
        logger.debug("Starting health check...")

        # A stopped background loop is reported before probing the engine.
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        await self.engine.check_health_async()
        elapsed = time.perf_counter() - started_at
        logger.debug("Health check took %fs", elapsed)
1286
1287

    async def is_tracing_enabled(self) -> bool:
        """Report whether the wrapped engine has tracing enabled."""
        wrapped = self.engine
        return wrapped.is_tracing_enabled()
1289
1290

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger on the wrapped engine under ``logger_name``."""
        wrapped = self.engine
        wrapped.add_logger(logger_name=logger_name, logger=logger)
1292
1293

    def remove_logger(self, logger_name: str) -> None:
        """Unregister the wrapped engine's stat logger named ``logger_name``."""
        wrapped = self.engine
        wrapped.remove_logger(logger_name=logger_name)
1295
1296

    async def start_profile(self) -> None:
        """Start profiling on the engine's model executor.

        An executor whose type is exactly ``GPUExecutorAsync`` has its own
        ``start_profile()`` called. Any other executor type, including
        subclasses of ``GPUExecutorAsync``, is driven through
        ``_run_workers("start_profile")``.
        """
        executor = self.engine.model_executor
        # Exact type identity (not isinstance) on purpose, so inherited
        # executor classes take the _run_workers path. ``is`` is the
        # idiomatic exact-type comparison and needs no E721 suppression.
        # NOTE(review): both branches are synchronous calls inside a
        # coroutine and may block the event loop while they run.
        if type(executor) is GPUExecutorAsync:
            executor.start_profile()
        else:
            executor._run_workers("start_profile")
1303
1304

    async def stop_profile(self) -> None:
        """Stop profiling on the engine's model executor.

        An executor whose type is exactly ``GPUExecutorAsync`` has its own
        ``stop_profile()`` called. Any other executor type, including
        subclasses of ``GPUExecutorAsync``, is driven through
        ``_run_workers("stop_profile")``.
        """
        executor = self.engine.model_executor
        # Exact type identity (not isinstance) on purpose, so inherited
        # executor classes take the _run_workers path. ``is`` is the
        # idiomatic exact-type comparison and needs no E721 suppression.
        # NOTE(review): both branches are synchronous calls inside a
        # coroutine and may block the event loop while they run.
        if type(executor) is GPUExecutorAsync:
            executor.stop_profile()
        else:
            executor._run_workers("stop_profile")