async_llm_engine.py 47 KB
Newer Older
1
2
import asyncio
import time
3
import weakref
Antoni Baum's avatar
Antoni Baum committed
4
from functools import partial
5
6
from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
                    List, Mapping, Optional, Set, Tuple, Type, Union, overload)
7
from weakref import ReferenceType
8

9
import vllm.envs as envs
10
11
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
12
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
13
from vllm.engine.arg_utils import AsyncEngineArgs
14
from vllm.engine.async_timeout import asyncio_timeout
15
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
16
from vllm.engine.metrics_types import StatLoggerBase
17
from vllm.executor.executor_base import ExecutorAsyncBase
18
from vllm.executor.gpu_executor import GPUExecutorAsync
19
from vllm.executor.ray_utils import initialize_ray_cluster
20
from vllm.inputs import PromptType
Woosuk Kwon's avatar
Woosuk Kwon committed
21
from vllm.logger import init_logger
22
from vllm.lora.request import LoRARequest
23
from vllm.model_executor.layers.sampler import SamplerOutput
24
25
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
26
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
27
from vllm.sampling_params import SamplingParams
28
from vllm.sequence import ExecuteModelRequest
29
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
30
from vllm.usage.usage_lib import UsageContext
31
from vllm.utils import deprecate_kwargs, weak_bind
32
33

# Module-level logger for this file, obtained from vLLM's logging helper.
logger = init_logger(__name__)
34
# Engine iteration timeout, read once at import time from `vllm.envs`
# (presumably expressed in seconds, going by the `_S` suffix -- TODO confirm).
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
35

Antoni Baum's avatar
Antoni Baum committed
36

37
38
39
40
class AsyncEngineDeadError(RuntimeError):
    """Signals that the engine's background loop has failed or is not running."""


41
42
43
44
45
46
47
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback meant exclusively for the `engine.run_engine_loop()` task.

    That task runs a `while True` loop, so the only ways it can complete are
    cancellation or an exception; a normal return is treated as an error too.

    Args:
        task: The finished background task whose outcome is inspected.
        error_callback: Invoked with the failure before it is re-raised.

    Raises:
        AsyncEngineDeadError: If the task ended for any reason other than
            cancellation. The original failure is chained as the cause.
    """
    try:
        outcome = task.result()
        # Reaching this point means the loop returned normally, which is
        # never expected; route it through the generic failure path below.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {outcome}")
    except asyncio.exceptions.CancelledError:
        # We assume that if the task is cancelled, we are gracefully shutting
        # down. This should only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as err:
        logger.error("Engine background task failed", exc_info=err)
        error_callback(err)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from err
67
68


69
70
71
# Unique exception instance that AsyncStream enqueues to mark the normal end
# of a stream, as opposed to a real error that should be raised.
STOP_ITERATION = Exception()  # Sentinel


Antoni Baum's avatar
Antoni Baum committed
72
class AsyncStream:
    """Per-request stream of RequestOutputs or EmbeddingRequestOutputs.

    Items are buffered in an asyncio queue and consumed asynchronously
    through :meth:`generator`.
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        # Called with the request id when the consumer closes the generator.
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Enqueue ``item``; silently ignored once the stream is finished."""
        if self._finished:
            return
        self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Close the stream, optionally with an exception for the consumer.

        Only the first call has an effect. When ``exception`` is not an
        exception instance or class, the end-of-stream sentinel is queued.
        """
        if self._finished:
            return
        self._finished = True
        if self._is_raisable(exception):
            terminal = exception
        else:
            terminal = STOP_ITERATION
        self._queue.put_nowait(terminal)

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has already been called."""
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Yield queued outputs until the stream ends or an error is queued.

        Closing the generator early cancels the request through the callback
        given at construction and surfaces ``asyncio.CancelledError``.
        """
        try:
            while True:
                item = await self._queue.get()
                if not self._is_raisable(item):
                    yield item
                    continue
                if item == STOP_ITERATION:
                    return
                raise item
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any):
        """Return True for an exception instance or an exception class."""
        if isinstance(value, BaseException):
            return True
        return isinstance(value, type) and issubclass(value, BaseException)

Antoni Baum's avatar
Antoni Baum committed
121

122
123
124
125
126
class RequestTracker:
    """Synchronous bookkeeping for active, newly added and aborted requests."""

    def __init__(self) -> None:
        # Streams of requests that have been handed over to the engine.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Ids waiting to be reported as aborted on the next poll.
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, engine kwargs) pairs waiting to be picked up.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Abort one request with ``exc``, or every tracked request when
        ``request_id`` is None."""
        if request_id is None:
            # Snapshot the ids first: abort_request pops entries from
            # self._request_streams, so it cannot be iterated directly.
            target_ids = tuple(self._request_streams.keys())
        else:
            target_ids = (request_id, )
        for rid in target_ids:
            self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Forward an engine output to the stream of the matching request."""
        request_id = request_output.request_id
        finished = request_output.finished

        # A finished request is dropped from the table; otherwise just peek.
        if finished:
            lookup = self._request_streams.pop
        else:
            lookup = self._request_streams.get
        stream = lookup(request_id, None)

        # The stream may be missing if the request was aborted while the
        # output was being generated.
        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if finished and verbose:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Deliver an engine-side exception to a single request."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Queue a request for the next background loop iteration and return
        the stream its outputs will arrive on.

        Raises:
            KeyError: If ``request_id`` is already being tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        on_cancel = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, on_cancel)
        engine_kwargs = {"request_id": request_id, **engine_add_request_kwargs}
        self._new_requests.put_nowait((stream, engine_kwargs))

        # Signal anyone blocked in wait_for_new_requests().
        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Mark a request as aborted for the next background loop iteration
        and finish its stream, if it has one."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Drain both queues and return ``(new requests, aborted ids)``.

        A new request whose id was aborted before pickup is finished with
        ``asyncio.CancelledError`` and removed from both results.
        """
        finished_requests: Set[str] = set()
        while not self._aborted_requests.empty():
            finished_requests.add(self._aborted_requests.get_nowait())

        new_requests: List[Dict] = []
        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            request_id = stream.request_id
            if request_id not in finished_requests:
                self._request_streams[request_id] = stream
                new_requests.append(new_request)
                continue
            # The request has already been aborted.
            stream.finish(asyncio.CancelledError)
            finished_requests.discard(request_id)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Wait for the new-request event unless requests are already queued,
        then reset the event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Whether at least one new request is waiting to be picked up."""
        return not self._new_requests.empty()
254

Antoni Baum's avatar
Antoni Baum committed
255
256
257
258

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

259
260
261
    def __init__(self, *args, **kwargs):
        """Forward all positional and keyword arguments unchanged to the
        parent class constructor."""
        super().__init__(*args, **kwargs)

262
    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.

        Args:
            virtual_engine: Index selecting the per-virtual-engine scheduler,
                cached scheduler outputs, scheduler context and async
                callback used for this step.

        Returns:
            The ``request_outputs`` list of the scheduler context belonging
            to ``virtual_engine``.
        """
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        allow_async_output_proc = cached_outputs.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()

        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):

            # Schedule iteration
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()

            # Expose the fresh schedule on the scheduler context.
            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)

        # Whether taken from the cache or freshly scheduled, both values must
        # be populated before the model can be executed.
        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, outputs)
        else:
            # Nothing was scheduled: flush any queued outputs and continue
            # with an empty list of model outputs.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []

        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # Clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()

            # is_first_step_output is True only when the num_steps of all
            # the sequences are 1. When the num_steps > 1,
            # multi_step_model_runner does the first-step output append.
            is_first_step_output: bool = False if not seq_group_metadata_list \
                else seq_group_metadata_list[0].state.num_steps == 1

            ctx.append_output(outputs=outputs,
                              seq_group_metadata_list=seq_group_metadata_list,
                              scheduler_outputs=scheduler_outputs,
                              is_async=allow_async_output_proc,
                              is_last_step=True,
                              is_first_step_output=is_first_step_output)

            if outputs and allow_async_output_proc:
                assert len(
                    outputs
                ) == 1, "Async postprocessor expects only a single output set"
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            # In sync mode the outputs are processed right away, followed by
            # stats logging and tracing.
            if not allow_async_output_proc:
                self._process_model_outputs(ctx=ctx)

                # Log stats.
                self.do_log_stats(scheduler_outputs, outputs)

                # Tracing
                self.do_tracing(scheduler_outputs)

        else:
            # Multi-step case
            return ctx.request_outputs

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs
407

408
409
410
411
    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Ask the model executor to stop the remote worker execution loop."""
        executor = self.model_executor
        await executor.stop_remote_worker_execution_loop_async()

412
    @overload  # DEPRECATED
    async def add_request_async(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @overload
    async def add_request_async(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request_async(
            self,
            request_id: str,
            prompt: Optional[PromptType] = None,
            params: Optional[Union[SamplingParams, PoolingParams]] = None,
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            priority: int = 0,
            *,
            inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
        """Async version of :meth:`add_request`.

        Preprocesses the prompt asynchronously, runs the input processor on
        the result and registers the processed request with the engine.

        Raises:
            ValueError: If a LoRA request is given while LoRA is disabled, or
                a non-zero priority is given while the scheduler policy is
                not "priority".
        """
        # The deprecated `inputs` keyword, when given, overrides `prompt`.
        prompt = inputs if inputs is not None else prompt
        assert prompt is not None and params is not None

        if lora_request is not None:
            if not self.lora_config:
                raise ValueError(f"Got lora_request {lora_request} but LoRA "
                                 "is not enabled!")
        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        # Default the arrival timestamp to "now" when the caller gave none.
        arrival_time = time.time() if arrival_time is None else arrival_time

        preprocessed = await self.input_preprocessor.preprocess_async(
            prompt,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed = self.input_processor(preprocessed)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
        )
490

491
    async def check_health_async(self) -> None:
        """Run the tokenizer health check (when a tokenizer is set) and then
        the model executor health check."""
        tokenizer = self.tokenizer
        if tokenizer:
            tokenizer.check_health()
        self.model_executor.check_health()
495

496

497
class AsyncLLMEngine:
498
    """An asynchronous wrapper for :class:`LLMEngine`.
499

500
501
502
503
504
    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.
505
506

    Args:
507
        log_requests: Whether to log the requests.
508
509
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
510
511
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
512
    """
513

Antoni Baum's avatar
Antoni Baum committed
514
515
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

516
517
518
    def __init__(self,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Build the wrapped engine and initialise background-loop state.

        ``*args`` and ``**kwargs`` are passed on to ``_engine_class``; the
        background loop itself is not started here.
        """
        self.log_requests = log_requests
        self.engine = self._engine_class(*args, **kwargs)

        # This ensures quick processing of request outputs
        # so the append to asyncio queues is not delayed,
        # especially for multi-step.
        use_callback = self.engine.model_config.use_async_output_proc
        self.use_process_request_outputs_callback = use_callback
        if use_callback:
            self.engine.process_request_outputs_callback = weak_bind(
                self.process_request_outputs)

        self.background_loop: Optional[asyncio.Future] = None
        # The unshielded task is referenced as well so that it is not
        # garbage collected.
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Lazy initialized fields
        self._request_tracker: RequestTracker

545
546
547
548
549
    def __del__(self):
        if rt := getattr(self, "request_tracker", None):
            # Wake up engine loop so that it will exit cleanly
            rt.new_requests_event.set()

550
    @classmethod
    def _get_executor_cls(
            cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
        """Pick the async executor class for the given engine configuration.

        The choice depends on ``parallel_config.distributed_executor_backend``
        and ``device_config.device_type``. Executor modules are imported
        lazily, only for the branch that is actually taken.

        Raises:
            TypeError: If the backend is a class that does not derive from
                ExecutorAsyncBase.
            RuntimeError: If the backend is not supported on an XPU device.
        """
        backend = engine_config.parallel_config.distributed_executor_backend

        # An executor class passed in directly is used as-is once validated.
        if isinstance(backend, type):
            if issubclass(backend, ExecutorAsyncBase):
                return backend
            raise TypeError(
                "distributed_executor_backend must be a subclass of "
                f"ExecutorAsyncBase. Got {backend}.")

        device_type = engine_config.device_config.device_type

        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            return NeuronExecutorAsync

        if device_type == "tpu":
            if backend == "ray":
                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                return RayTPUExecutorAsync
            assert backend is None
            from vllm.executor.tpu_executor import TPUExecutorAsync
            return TPUExecutorAsync

        if device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutorAsync
            return CPUExecutorAsync

        if device_type == "openvino":
            assert backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            return OpenVINOExecutorAsync

        if device_type == "xpu":
            if backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                return XPUExecutorAsync
            if backend == "ray":
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                return RayXPUExecutorAsync
            if backend == "mp":
                from vllm.executor.multiproc_xpu_executor import (
                    MultiprocessingXPUExecutorAsync)
                return MultiprocessingXPUExecutorAsync
            raise RuntimeError(
                "Not supported distributed execution model on XPU device.")

        # Any other device type: choose by distributed backend only.
        if backend == "ray":
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            return RayGPUExecutorAsync

        if backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            return MultiprocessingGPUExecutorAsync

        from vllm.executor.gpu_executor import GPUExecutorAsync
        return GPUExecutorAsync

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        engine_config: Optional[EngineConfig] = None,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        When ``engine_config`` is not supplied it is built from
        ``engine_args``. If the selected executor class reports
        ``uses_ray``, the Ray cluster is initialised before construction.
        """
        # Create the engine configs.
        config = (engine_config if engine_config is not None else
                  engine_args.create_engine_config())

        executor_class = cls._get_executor_cls(config)
        if executor_class.uses_ray:
            initialize_ray_cluster(config.parallel_config)

        # Create the async LLM engine.
        return cls(
            **config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

638
639
    @property
    def is_running(self) -> bool:
640
        return (self.background_loop is not None
641
                and self._background_loop_unshielded is not None
642
643
644
645
                and not self._background_loop_unshielded.done())

    @property
    def is_stopped(self) -> bool:
646
647
        return self.errored or (self.background_loop is not None and
                                self._background_loop_unshielded is not None
648
649
650
651
652
653
                                and self._background_loop_unshielded.done())

    @property
    def errored(self) -> bool:
        return self._errored_with is not None

654
    @property
    def dead_error(self) -> BaseException:
        """A fresh AsyncEngineDeadError explaining that the background loop
        is not running."""
        message = ("Background loop is not running. If it was running, "
                   "inspect the output to find the stacktrace of the "
                   "error that caused the background loop to stop "
                   "(AsyncEngineDeadError).")
        return AsyncEngineDeadError(message)
661

662
663
664
665
666
667
    def set_errored(self, exc: Exception) -> None:
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)
668

669
670
671
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
672
    ) -> AnyTokenizer:
673
674
        return await (self.engine.get_tokenizer_group().
                      get_lora_tokenizer_async(lora_request))
675

676
    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the background loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        # The loop coroutine only receives a weak reference to this engine.
        loop = asyncio.get_event_loop()
        engine_task = loop.create_task(self.run_engine_loop(weakref.ref(self)))
        self._background_loop_unshielded = engine_task
        engine_task.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        self.background_loop = asyncio.shield(engine_task)
Antoni Baum's avatar
Antoni Baum committed
691

692
693
694
695
696
697
698
699
700
701
702
703
704
705
    def shutdown_background_loop(self) -> None:
        """
        Shut down the background loop.

        Cancels the unshielded loop task (if there is one) and clears
        both loop handles. Call this during cleanup so that references
        to `self` are removed and the resources held by the async LLM
        engine (e.g., the executors as well as their resources) can be
        garbage collected.
        """
        loop_task = self._background_loop_unshielded
        if loop_task is not None:
            loop_task.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

706
    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Drains new and aborted requests from the request tracker, feeds
        them to the engine, runs one async engine step for
        `virtual_engine` and dispatches the step's outputs.

        Returns True if there are in-progress requests."""
        pending, to_abort = (
            self._request_tracker.get_new_and_aborted_requests())

        for request_kwargs in pending:
            # Add the request into the vLLM engine's waiting queue.
            try:
                await self.engine.add_request_async(**request_kwargs)
            except ValueError as err:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    request_kwargs["request_id"],
                    err,
                    verbose=self.log_requests,
                )

        if to_abort:
            await self._engine_abort(to_abort)

        step_outputs = await self.engine.step_async(virtual_engine)

        if self.use_process_request_outputs_callback:
            # The outputs were already put into their streams from inside
            # LLMEngine's _process_model_outputs (callback case), so we
            # only need to detect when all requests are finished.
            all_finished = all(out.finished for out in step_outputs)
        else:
            # Put the outputs into the corresponding streams.
            all_finished = self.process_request_outputs(step_outputs)

        return not all_finished

    def process_request_outputs(self, request_outputs) -> bool:
        """Forward each output to the request tracker.

        Returns True when every output in `request_outputs` is finished
        (and therefore also when `request_outputs` is empty).
        """
        every_finished = True
        for output in request_outputs:
            # Put the output into its corresponding stream.
            self._request_tracker.process_request_output(
                output, verbose=self.log_requests)
            if every_finished:
                every_finished = output.finished

        return every_finished
753

Antoni Baum's avatar
Antoni Baum committed
754
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Ask the wrapped engine to abort the given request ids."""
        engine = self.engine
        engine.abort_request(request_ids)
Antoni Baum's avatar
Antoni Baum committed
756

757
758
759
760
761
762
763
764
    @staticmethod
    async def run_engine_loop(engine_ref: ReferenceType):
        """Run engine steps for every virtual engine in a loop.

        We use a weakref to the engine so that the running loop
        doesn't prevent the engine being garbage collected. The loop
        returns as soon as the weak reference no longer resolves, and
        re-raises `asyncio.TimeoutError` (after marking the engine as
        errored) if waiting for a step exceeds
        ENGINE_ITERATION_TIMEOUT_S.
        """
        engine: Optional["AsyncLLMEngine"] = engine_ref()
        if not engine:
            return

        num_virtual_engines = (
            engine.engine.parallel_config.pipeline_parallel_size)
        # One flag per virtual engine: does it still have work to step on?
        step_pending = [False] * num_virtual_engines
        while True:
            if not any(step_pending):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                await engine.engine.stop_remote_worker_execution_loop_async()
                tracker = engine._request_tracker
                # Allow engine to be garbage collected while
                # waiting for new requests
                del engine
                await asyncio.sleep(0)
                if engine_ref() is None:
                    return
                await tracker.wait_for_new_requests()
                engine = engine_ref()
                if not engine:
                    return
                logger.debug("Got new requests!")
                step_tasks = [
                    asyncio.create_task(engine.engine_step(ve))
                    for ve in range(num_virtual_engines)
                ]
                step_pending = [True] * num_virtual_engines

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    finished_tasks, _ = await asyncio.wait(
                        step_tasks, return_when=asyncio.FIRST_COMPLETED)
                    # Yield to the event loop once per virtual engine.
                    for _ in range(num_virtual_engines):
                        await asyncio.sleep(0)
                for finished in finished_tasks:
                    step_result = finished.result()
                    # A task's position in the list is its virtual engine.
                    ve_index = step_tasks.index(finished)
                    still_unfinished = (
                        engine.engine.
                        has_unfinished_requests_for_virtual_engine(ve_index))
                    if step_result or still_unfinished:
                        # More work for this virtual engine: schedule
                        # its next step right away.
                        step_tasks[ve_index] = asyncio.create_task(
                            engine.engine_step(ve_index))
                        step_pending[ve_index] = True
                    else:
                        step_pending[ve_index] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                engine.set_errored(exc)
                raise
            await asyncio.sleep(0)

826
827
    # This method does not need to be async, but kept that way
    # for backwards compatibility.
    @overload  # DEPRECATED
    def add_request(
        self,
        request_id: str,
        *,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, EmbeddingRequestOutput], None]]:
        ...

    @overload
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Coroutine[None, None, AsyncGenerator[Union[
            RequestOutput, EmbeddingRequestOutput], None]]:
        ...

    @deprecate_kwargs(
        "inputs",
        additional_message="Please use the 'prompt' parameter instead.",
    )
    async def add_request(
        self,
        request_id: str,
        prompt: Optional[PromptType] = None,
        params: Optional[Union[SamplingParams, PoolingParams]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Register a request with the request tracker.

        If the background loop is not running it is started first,
        provided `self.start_engine_loop` is set.

        Args:
            request_id: The unique id of the request.
            prompt: The prompt of the request.
            params: Sampling or pooling parameters of the request.
            arrival_time: Arrival timestamp; the current time is used
                when it is not given.
            lora_request: LoRA request to use, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use, if any.
            priority: The priority of the request. Must be 0 unless the
                scheduler policy is "priority".
            inputs: Deprecated alias of `prompt`; takes precedence over
                `prompt` when given.

        Returns:
            The async generator of the stream created for this request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running
                and `self.start_engine_loop` is false.
            ValueError: If a non-zero priority is given while priority
                scheduling is not enabled.
        """
        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None

        if not self.is_running:
            if not self.start_engine_loop:
                # Single source of truth for the dead-engine message.
                raise self.dead_error
            self.start_background_loop()

        if (priority != 0
                and self.engine.scheduler_config.policy != "priority"):
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            prompt=prompt,
            params=params,
            arrival_time=arrival_time or time.time(),
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

        return stream.generator()
908

909
    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        This method is an async generator. It adds the request into the
        waiting queue of the LLMEngine (through `add_request`) and
        streams the outputs from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use
                for generation, if any.
            priority: The priority of the request.
                Only applicable with priority scheduling.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        stream = await self.add_request(
            request_id,
            prompt,
            sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
        async for request_output in stream:
            yield LLMEngine.validate_output(request_output, RequestOutput)
994
995
996

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model.

        This method is an async generator. It adds the request into the
        waiting queue of the LLMEngine (through `add_request`) and
        streams the outputs from the LLMEngine to the caller.

        Args:
            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.

        Yields:
            The output `EmbeddingRequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        stream = await self.add_request(
            request_id,
            prompt,
            pooling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )
        async for request_output in stream:
            yield LLMEngine.validate_output(request_output,
                                            EmbeddingRequestOutput)
1070

Antoni Baum's avatar
Antoni Baum committed
1071
1072
    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if not self.is_running:
            # Reuse the shared dead-engine error instead of duplicating
            # its message here.
            raise self.dead_error

        self._abort(request_id)
1088

Antoni Baum's avatar
Antoni Baum committed
1089
    def _abort(self, request_id: str) -> None:
        """Abort a request.

        Marks the request as aborted in the request tracker, using
        `asyncio.CancelledError` as the exception for the request. If
        the request is finished or not found, this method will be a
        no-op.

        Args:
            request_id: The unique id of the request.
        """
        tracker = self._request_tracker
        tracker.abort_request(
            request_id,
            exception=asyncio.CancelledError,
            verbose=self.log_requests,
        )
1101

1102
1103
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        model_config = self.engine.get_model_config()
        return model_config
1105

1106
1107
    async def get_parallel_config(self) -> ParallelConfig:
        """Get the parallel configuration of the vLLM engine."""
        parallel_config = self.engine.get_parallel_config()
        return parallel_config
1109

1110
1111
    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        decoding_config = self.engine.get_decoding_config()
        return decoding_config
1113

1114
1115
    async def get_scheduler_config(self) -> SchedulerConfig:
        """Get the scheduling configuration of the vLLM engine."""
        scheduler_config = self.engine.get_scheduler_config()
        return scheduler_config
1117
1118
1119

    async def get_lora_config(self) -> LoRAConfig:
        """Get the lora configuration of the vLLM engine."""
        lora_config = self.engine.get_lora_config()
        return lora_config
1121

1122
1123
1124
1125
    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[List[SamplerOutput]] = None,
    ) -> None:
        """Trigger stats logging on the wrapped engine.

        NOTE: `scheduler_outputs` and `model_output` are accepted but
        are not forwarded to the engine's `do_log_stats`.
        """
        engine = self.engine
        engine.do_log_stats()
1127

1128
    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy.

        Raises:
            AsyncEngineDeadError: If the background loop is stopped.
                Errors from the engine's own async health check are
                not caught here.
        """
        started_at = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        await self.engine.check_health_async()
        elapsed = time.perf_counter() - started_at
        logger.debug("Health check took %fs", elapsed)
1137
1138

    async def is_tracing_enabled(self) -> bool:
        """Report whether the wrapped engine has tracing enabled."""
        enabled = self.engine.is_tracing_enabled()
        return enabled
1140
1141

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register `logger` under `logger_name` on the wrapped engine."""
        engine = self.engine
        engine.add_logger(logger_name=logger_name, logger=logger)
1143
1144

    def remove_logger(self, logger_name: str) -> None:
        """Remove the logger named `logger_name` from the wrapped engine."""
        engine = self.engine
        engine.remove_logger(logger_name=logger_name)
1146
1147

    async def start_profile(self) -> None:
        """Start profiling on the model executor.

        A plain `GPUExecutorAsync` is asked directly; any other executor
        (including subclasses) is asked to run "start_profile" on its
        workers.
        """
        executor = self.engine.model_executor
        # using type instead of isinstance to check to avoid capturing
        # inherited classes
        if type(executor) == GPUExecutorAsync:  # noqa: E721
            executor.start_profile()
        else:
            executor._run_workers("start_profile")
1154
1155

    async def stop_profile(self) -> None:
        """Stop profiling on the model executor.

        A plain `GPUExecutorAsync` is asked directly; any other executor
        (including subclasses) is asked to run "stop_profile" on its
        workers.
        """
        executor = self.engine.model_executor
        # using type instead of isinstance to check to avoid capturing
        # inherited classes
        if type(executor) == GPUExecutorAsync:  # noqa: E721
            executor.stop_profile()
        else:
            executor._run_workers("stop_profile")