async_llm_engine.py 44.4 KB
Newer Older
1
2
import asyncio
import time
Antoni Baum's avatar
Antoni Baum committed
3
from functools import partial
4
from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
5
                    Optional, Set, Tuple, Type, Union)
6

7
from transformers import PreTrainedTokenizer
8
from typing_extensions import assert_never
9

10
import vllm.envs as envs
11
12
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
13
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
14
from vllm.engine.arg_utils import AsyncEngineArgs
15
from vllm.engine.async_timeout import asyncio_timeout
16
17
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
                                    PromptComponents)
18
from vllm.engine.metrics import StatLoggerBase
19
from vllm.executor.executor_base import ExecutorAsyncBase
20
from vllm.executor.ray_utils import initialize_ray_cluster, ray
21
22
23
from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
                         SingletonPromptInputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
Woosuk Kwon's avatar
Woosuk Kwon committed
24
from vllm.logger import init_logger
25
from vllm.lora.request import LoRARequest
26
27
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
28
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
29
from vllm.sampling_params import SamplingParams
30
from vllm.sequence import ExecuteModelRequest, SamplerOutput
yhu422's avatar
yhu422 committed
31
from vllm.usage.usage_lib import UsageContext
32
33

# Module-level logger, obtained from vLLM's init_logger under this module's
# name; used by the helpers and classes below for request/engine messages.
logger = init_logger(__name__)
34
# Snapshot of vllm.envs.VLLM_ENGINE_ITERATION_TIMEOUT_S taken at import time.
# NOTE(review): presumably a timeout in seconds for one engine iteration (per
# the `_S` suffix); its use is outside the code visible here -- confirm.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
35

Antoni Baum's avatar
Antoni Baum committed
36

37
38
39
40
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine background task has failed or has already
    errored, so the async engine can no longer serve requests."""


41
42
43
44
45
46
47
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback for the `engine.run_engine_loop()` task only.

    That task is expected to loop forever, so it can only complete by being
    cancelled or by raising. A normal return is therefore treated as a
    failure as well.

    Args:
        task: The finished background task.
        error_callback: Called with the exception that ended the task
            (not called when the task was cancelled).

    Raises:
        AsyncEngineDeadError: If the task ended for any reason other than
            cancellation; chained to the underlying exception.
    """
    try:
        # `result()` re-raises whatever terminated the task.
        outcome = task.result()
        # Reaching this line means the loop returned normally, which should
        # be impossible; convert it into an error handled just below.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {outcome}")
    except asyncio.exceptions.CancelledError:
        # Cancellation is taken to mean a graceful shutdown, which should
        # only happen when the program exits.
        logger.info("Engine is gracefully shutting down.")
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from e
67
68


69
70
71
# Unique end-of-stream marker: AsyncStream.finish() enqueues this object and
# AsyncStream.generator() stops iterating when it dequeues it.
STOP_ITERATION = Exception()  # Sentinel


Antoni Baum's avatar
Antoni Baum committed
72
class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator.

    Producers call :meth:`put` for every output (or exception) and
    :meth:`finish` exactly once at the end; the consumer iterates over
    :meth:`generator`.
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        """
        Args:
            request_id: Id of the request this stream belongs to.
            cancel: Callback invoked with ``request_id`` when the consumer
                closes the generator before the stream has finished.
        """
        self.request_id = request_id
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Enqueue an output or an exception; ignored once finished."""
        if self._finished:
            return
        self._queue.put_nowait(item)

    def finish(self, cancelled: bool = False) -> None:
        """Mark the stream as finished and enqueue the terminating item.

        Args:
            cancelled: If True the consumer gets ``asyncio.CancelledError``
                raised from the generator; otherwise iteration simply ends.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                asyncio.CancelledError() if cancelled else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has been called."""
        return self._finished

    @staticmethod
    def _is_raisable(value) -> bool:
        """Return True for exception instances and exception classes.

        ``BaseException`` is used on purpose: ``asyncio.CancelledError`` is
        not an ``Exception`` subclass, so an ``Exception`` check would let it
        fall through and be yielded to the consumer as if it were an output.
        """
        return isinstance(value, BaseException) or (
            isinstance(value, type) and issubclass(value, BaseException))

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Yield queued outputs until the stream terminates.

        Every item enqueued before :meth:`finish` is delivered, even if the
        stream was finished while the consumer was busy. Queued exceptions
        are raised in order; the ``STOP_ITERATION`` sentinel ends iteration.

        Raises:
            asyncio.CancelledError: If the stream was finished with
                ``cancelled=True``, or if the consumer closes the generator
                early (in which case the cancel callback is invoked first).
        """
        try:
            # Keep draining while anything is left in the queue. `finish()`
            # always enqueues a terminating item, so only a stream whose
            # terminator was already consumed exits through this condition.
            while not (self._finished and self._queue.empty()):
                result = await self._queue.get()
                if self._is_raisable(result):
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            # The consumer went away: abort the request on the engine side.
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None
Antoni Baum's avatar
Antoni Baum committed
112
113


114
115
116
117
118
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    New and aborted requests are buffered in queues and handed to the engine
    in one batch by :meth:`get_new_and_aborted_requests`; outputs and errors
    coming back from the engine are routed to the per-request
    :class:`AsyncStream`.
    """

    def __init__(self) -> None:
        # Streams of requests that have been handed over to the engine.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Ids whose abort still has to be forwarded to the engine.
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, add_request kwargs) pairs not yet sent to the engine.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set whenever a request is added; awaited by the background loop.
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None).

        Each targeted request is aborted afterwards. A ``request_id`` that
        is no longer tracked (e.g. it was aborted in the meantime) is not an
        error: there is no stream left to notify, and only the abort is
        recorded.
        """
        if request_id is not None:
            request_ids = [request_id]
        else:
            # Snapshot the ids: abort_request pops entries out of
            # self._request_streams, so it cannot be iterated directly.
            request_ids = list(self._request_streams)

        for rid in request_ids:
            stream = self._request_streams.get(rid)
            if stream is not None:
                stream.put(exc)
            self.abort_request(rid)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine."""
        request_id = request_output.request_id
        finished = request_output.finished

        if finished:
            stream = self._request_streams.pop(request_id, None)
        else:
            stream = self._request_streams.get(request_id)
        # Guard against a KeyError which can occur if the request was aborted
        # while the output was generated
        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if verbose and finished:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: Exception,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine.

        The request is aborted afterwards. If its stream is already gone
        (the request was aborted concurrently) the exception has no consumer
        left and is dropped instead of raising ``KeyError`` into the caller.
        """
        stream = self._request_streams.get(request_id)
        if stream is not None:
            stream.put(exception)
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Returns:
            The stream on which the request's outputs will be delivered.

        Raises:
            KeyError: If a request with this id is already being tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        # The stream calls this when its consumer stops iterating early.
        abort_request = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, abort_request)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      cancelled: bool = False,
                      verbose: bool = False) -> None:
        """Abort a request during next background loop iteration.

        The id is queued for the engine and, if the request is currently
        tracked, its stream is removed and finished (``cancelled`` is
        forwarded to :meth:`AsyncStream.finish`).
        """
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(cancelled=cancelled)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        New requests start being tracked here, except those that were
        aborted before ever reaching the engine: their stream is finished as
        cancelled and they are not returned.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._aborted_requests.empty():
            request_id = self._aborted_requests.get_nowait()
            finished_requests.add(request_id)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            if stream.request_id in finished_requests:
                # The request has already been aborted.
                stream.finish(cancelled=True)
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Wait until a new request is pending, then reset the event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Return True if requests are waiting to be sent to the engine."""
        return not self._new_requests.empty()
246

Antoni Baum's avatar
Antoni Baum committed
247
248
249
250

class _AsyncLLMEngine(LLMEngine):
    """LLMEngine subclass that adds coroutine variants of its methods."""

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Run one engine iteration for the given virtual engine.

        The scheduler selected by ``virtual_engine`` decides what runs next.
        Unless nothing was scheduled, the batch is passed to the model
        executor through ``execute_model_async``. The model output is then
        post-processed into request outputs, and stats and tracing are
        recorded.

        Args:
            virtual_engine: Index of the scheduler in ``self.scheduler``.

        Returns:
            The request outputs produced by ``_process_model_outputs``.
        """
        scheduler = self.scheduler[virtual_engine]
        seq_group_metadata_list, scheduler_outputs = scheduler.schedule()

        if scheduler_outputs.is_empty():
            # Nothing scheduled: skip the model call entirely.
            output = []
        else:
            finished_ids = scheduler.get_and_reset_finished_requests_ids()
            request = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_ids)
            output = await self.model_executor.execute_model_async(request)

        request_outputs = self._process_model_outputs(
            output,
            scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups,
            seq_group_metadata_list,
        )

        # Bookkeeping: stats first, then tracing.
        self.do_log_stats(scheduler_outputs, output)
        self.do_tracing(scheduler_outputs)

        return request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Ask the model executor to stop the remote worker execution
        loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def _tokenize_prompt_async(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        """Coroutine counterpart of :meth:`_tokenize_prompt`: encode
        ``prompt`` with the engine's tokenizer group."""
        tokenizer_group = self.get_tokenizer_group(
            "prompts must be None if "
            "skip_tokenizer_init is True")

        token_ids = await tokenizer_group.encode_async(
            request_id=request_id,
            prompt=prompt,
            lora_request=lora_request)
        return token_ids

    async def _extract_prompt_components_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> PromptComponents:
        """Coroutine counterpart of :meth:`_extract_prompt_components`.

        Returns:
            A ``(prompt, prompt_token_ids, multi_modal_data)`` tuple. A plain
            string is tokenized and carries no multi-modal data; a dict
            either supplies ``prompt_token_ids`` directly (``prompt`` is then
            None) or has its ``prompt`` entry tokenized.
        """
        if isinstance(inputs, str):
            token_ids = await self._tokenize_prompt_async(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
            )
            return inputs, token_ids, None

        if not isinstance(inputs, dict):
            assert_never(inputs)

        if "prompt_token_ids" in inputs:
            # Pre-tokenized input: there is no text to keep.
            prompt = None
            token_ids = inputs["prompt_token_ids"]
        else:
            # NOTE: This extra assignment is required to pass mypy
            prompt = parsed_prompt = inputs["prompt"]
            token_ids = await self._tokenize_prompt_async(
                parsed_prompt,
                request_id=request_id,
                lora_request=lora_request,
            )

        return prompt, token_ids, inputs.get("multi_modal_data")

    async def _process_encoder_decoder_prompt_async(
        self,
        inputs: PromptInputs,
        request_id: str,
    ) -> EncoderDecoderLLMInputs:
        """Coroutine counterpart of
        :meth:`_process_encoder_decoder_prompt`.

        An explicit encoder/decoder prompt has both halves extracted (in
        parallel when a decoder prompt is given); any other prompt is used
        as the encoder prompt with empty decoder components.
        """
        encoder_comps: PromptComponents
        decoder_comps: DecoderPromptComponents

        if not is_explicit_encoder_decoder_prompt(inputs):
            encoder_comps = await self._extract_prompt_components_async(
                inputs,
                request_id=request_id,
            )
            decoder_comps = None, None, None
            return self._build_enc_dec_llm_inputs(encoder_comps,
                                                  decoder_comps)

        encoder_coro = self._extract_prompt_components_async(
            inputs["encoder_prompt"],
            request_id=request_id,
        )
        decoder_input = inputs["decoder_prompt"]

        if decoder_input is None:
            encoder_comps = await encoder_coro
            decoder_comps = None, None, None
        else:
            decoder_coro = self._extract_prompt_components_async(
                decoder_input,
                request_id=request_id,
            )
            # Process both prompts concurrently.
            encoder_comps, decoder_comps = await asyncio.gather(
                encoder_coro, decoder_coro)

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)

    async def _process_decoder_only_prompt_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        """Coroutine counterpart of :meth:`_process_decoder_only_prompt`."""
        components = await self._extract_prompt_components_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        )
        return self._build_decoder_only_llm_inputs(
            components,
            prompt_adapter_request=prompt_adapter_request,
        )

    async def process_model_inputs_async(
        self,
        inputs: PromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
        """Coroutine counterpart of :meth:`process_model_inputs`.

        Raises:
            ValueError: If an explicit encoder/decoder prompt is passed to a
                decoder-only model.
        """
        if self.is_encoder_decoder_model():
            # Encoder-decoder models map the prompt onto separate encoder
            # and decoder inputs.
            model_inputs = await self._process_encoder_decoder_prompt_async(
                inputs,
                request_id=request_id,
            )
            return self.input_processor(model_inputs)

        if is_explicit_encoder_decoder_prompt(inputs):
            raise ValueError("Cannot pass encoder-decoder prompt "
                             "to decoder-only models")

        # Decoder-only path.
        model_inputs = await self._process_decoder_only_prompt_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        return self.input_processor(model_inputs)

    async def add_request_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        """Coroutine counterpart of :meth:`add_request`.

        Raises:
            ValueError: If a LoRA request is given while LoRA is not
                enabled on this engine.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

        # Default the arrival time to "now" when the caller gave none.
        if arrival_time is None:
            arrival_time = time.time()

        processed = await self.process_model_inputs_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
        )

    async def check_health_async(self) -> None:
        """Run the health checks of the tokenizer (if any) and of the model
        executor."""
        tokenizer = self.tokenizer
        if tokenizer:
            tokenizer.check_health()
        self.model_executor.check_health()
471

472

473
class AsyncLLMEngine:
    """An asynchronous wrapper for :class:`LLMEngine`.

    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
            async frontend will be executed in a separate process as the
            model workers.
        log_requests: Whether to log the requests.
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
    """
495

Antoni Baum's avatar
Antoni Baum committed
496
497
    # Engine class instantiated by _init_engine, either directly or wrapped
    # as a Ray remote.
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

498
499
500
501
502
    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Store the configuration flags and build the wrapped engine.

        ``*args`` and ``**kwargs`` are forwarded unchanged to
        :meth:`_init_engine`. The background loop is not started here.
        """
        # _init_engine reads both Ray flags, so they must be set first.
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
        self.engine = self._init_engine(*args, **kwargs)

        # Background-loop state. `background_loop` is the shielded future;
        # the raw task is referenced as well so that it is not garbage
        # collected.
        self.background_loop: Optional[asyncio.Future] = None
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Created lazily by start_background_loop.
        self._request_tracker: RequestTracker

521
    @classmethod
    def _get_executor_cls(
            cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
        """Select the async executor class for ``engine_config``.

        A class passed as ``distributed_executor_backend`` wins; otherwise
        the choice is driven by the device type and then by the backend
        name. Executor modules are imported only for the branch that is
        taken, and the Ray cluster is initialized for Ray-based executors.

        Raises:
            TypeError: If a backend class is given that does not subclass
                ``ExecutorAsyncBase``.
            RuntimeError: If an unsupported backend is requested on XPU.
        """
        parallel_config = engine_config.parallel_config
        backend = parallel_config.distributed_executor_backend

        # A user-supplied executor class.
        if isinstance(backend, type):
            if not issubclass(backend, ExecutorAsyncBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorAsyncBase. Got {backend}.")
            if backend.uses_ray:  # type: ignore
                initialize_ray_cluster(parallel_config)
            return backend

        device_type = engine_config.device_config.device_type

        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            return NeuronExecutorAsync

        if device_type == "tpu":
            if backend == "ray":
                initialize_ray_cluster(parallel_config)
                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                return RayTPUExecutorAsync
            assert backend is None
            from vllm.executor.tpu_executor import TPUExecutorAsync
            return TPUExecutorAsync

        if device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutorAsync
            return CPUExecutorAsync

        if device_type == "openvino":
            assert backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            return OpenVINOExecutorAsync

        if device_type == "xpu":
            if backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                return XPUExecutorAsync
            if backend == "ray":
                initialize_ray_cluster(parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                return RayXPUExecutorAsync
            raise RuntimeError(
                "Not supported distributed execution model on XPU device.")

        # Remaining device types fall through to the GPU executors.
        if backend == "ray":
            initialize_ray_cluster(parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            return RayGPUExecutorAsync

        if backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            return MultiprocessingGPUExecutorAsync

        from vllm.executor.gpu_executor import GPUExecutorAsync
        return GPUExecutorAsync

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Build an async LLM engine from parsed engine arguments.

        The engine config is created from ``engine_args``, Ray availability
        is asserted when the engine itself should run as a Ray actor, and
        the executor class is resolved via :meth:`_get_executor_cls`.
        """
        engine_config = engine_args.create_engine_config()

        if engine_args.engine_use_ray:
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()

        executor_cls = cls._get_executor_cls(engine_config)

        # The first positional argument (worker_use_ray) follows the
        # executor class; the config entries are expanded into keywords.
        return cls(
            executor_cls.uses_ray,
            engine_args.engine_use_ray,
            **engine_config.to_dict(),
            executor_class=executor_cls,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

611
612
    @property
    def is_running(self) -> bool:
613
        return (self.background_loop is not None
614
                and self._background_loop_unshielded is not None
615
616
617
618
                and not self._background_loop_unshielded.done())

    @property
    def is_stopped(self) -> bool:
619
620
        return self.errored or (self.background_loop is not None and
                                self._background_loop_unshielded is not None
621
622
623
624
625
626
627
628
629
630
631
632
                                and self._background_loop_unshielded.done())

    @property
    def errored(self) -> bool:
        """Whether an error has been recorded via :meth:`set_errored`."""
        return self._errored_with is not None

    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` as the error that put the engine into the errored
        state (see :attr:`errored`)."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine as errored and propagate ``exc`` through the
        request tracker; with no request id given, it goes to all tracked
        request streams.

        Registered as the ``error_callback`` of the background task's
        done-callback in :meth:`start_background_loop`.
        """
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)
633

634
635
636
637
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> "PreTrainedTokenizer":
        """Return the tokenizer for ``lora_request``.

        With a local engine the tokenizer group is queried directly; with a
        Ray-hosted engine the call is made remotely.
        """
        if not self.engine_use_ray:
            tokenizer_group = self.engine.get_tokenizer_group()
            return await tokenizer_group.get_lora_tokenizer_async(
                lora_request)

        return await self.engine.get_tokenizer.remote(  # type: ignore
            lora_request)
644

645
    def start_background_loop(self) -> None:
        """Start the background loop.

        A fresh :class:`RequestTracker` is created, the engine loop is
        scheduled as a task on the current event loop, and a done-callback
        is attached that reports failures through :meth:`_error_callback`.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the background loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        event_loop = asyncio.get_event_loop()
        engine_task = event_loop.create_task(self.run_engine_loop())
        self._background_loop_unshielded = engine_task

        on_done = partial(_log_task_completion,
                          error_callback=self._error_callback)
        engine_task.add_done_callback(on_done)

        # Expose a shielded future as the public handle to the task.
        self.background_loop = asyncio.shield(engine_task)
Antoni Baum's avatar
Antoni Baum committed
660
661
662

    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
        """Instantiate the engine, locally or as a Ray remote.

        Without ``engine_use_ray`` the engine class is constructed directly.
        Otherwise it is wrapped with ``ray.remote``: with ``num_cpus=0`` when
        the workers use Ray as well, else with a GPU request derived from
        the cache/parallel configs found in ``kwargs``.
        """
        if not self.engine_use_ray:
            return self._engine_class(*args, **kwargs)

        if self.worker_use_ray:
            remote_factory = ray.remote(num_cpus=0)(self._engine_class).remote
            return remote_factory(*args, **kwargs)

        # FIXME(woosuk): This is a bit hacky. Be careful when changing the
        # order of the arguments.
        cache_config = kwargs["cache_config"]
        parallel_config = kwargs["parallel_config"]
        unsharded = (parallel_config.tensor_parallel_size == 1
                     and parallel_config.pipeline_parallel_size == 1)
        num_gpus = cache_config.gpu_memory_utilization if unsharded else 1
        remote_factory = ray.remote(num_gpus=num_gpus)(
            self._engine_class).remote
        return remote_factory(*args, **kwargs)

681
    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Pending new requests are added to the engine, pending aborts are
        forwarded, one engine step is executed, and the resulting outputs
        are routed to their request streams.

        Returns True if there are in-progress requests, i.e. if at least one
        output of this step is not finished."""
        tracker = self._request_tracker
        new_requests, aborted_requests = tracker.get_new_and_aborted_requests()

        for request_kwargs in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            # TODO: Maybe add add_request_batch to reduce Ray overhead
            try:
                if not self.engine_use_ray:
                    await self.engine.add_request_async(**request_kwargs)
                else:
                    await self.engine.add_request.remote(  # type: ignore
                        **request_kwargs)
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    request_kwargs["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if aborted_requests:
            await self._engine_abort(aborted_requests)

        if not self.engine_use_ray:
            request_outputs = await self.engine.step_async(virtual_engine)
        else:
            request_outputs = await self.engine.step.remote()  # type: ignore

        # Put the outputs into the corresponding streams.
        has_unfinished = False
        for output in request_outputs:
            self._request_tracker.process_request_output(
                output, verbose=self.log_requests)
            if not output.finished:
                has_unfinished = True

        return has_unfinished
722

Antoni Baum's avatar
Antoni Baum committed
723
724
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Abort the given requests on the underlying engine.

        Args:
            request_ids: Ids of the requests to abort. They are passed
                through unchanged to the engine's ``abort_request``.
        """
        if self.engine_use_ray:
            await self.engine.abort_request.remote(request_ids)  # type: ignore
        else:
            self.engine.abort_request(request_ids)

    async def run_engine_loop(self):
        """Run engine steps forever, one task per virtual engine.

        While no virtual engine has work, the remote worker execution loop
        is stopped and the coroutine waits on the request tracker. Once
        requests arrive, one ``engine_step`` task is started per virtual
        engine; each finished task is restarted for as long as its step
        reported in-progress requests or the engine still has unfinished
        requests for that virtual engine.

        Raises:
            asyncio.TimeoutError: If waiting for a step exceeds
                ``ENGINE_ITERATION_TIMEOUT_S``. The error is recorded via
                ``set_errored`` before being re-raised.
        """
        if self.engine_use_ray:
            pipeline_parallel_size = 1  # type: ignore
        else:
            pipeline_parallel_size = \
                self.engine.parallel_config.pipeline_parallel_size
        # One flag per virtual engine; all False means the loop is idle.
        has_requests_in_progress = [False] * pipeline_parallel_size
        while True:
            if not any(has_requests_in_progress):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                if self.engine_use_ray:
                    await (self.engine.stop_remote_worker_execution_loop.
                           remote()  # type: ignore
                           )
                else:
                    await self.engine.stop_remote_worker_execution_loop_async()
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")
                requests_in_progress = [
                    asyncio.create_task(self.engine_step(ve))
                    for ve in range(pipeline_parallel_size)
                ]
                has_requests_in_progress = [True] * pipeline_parallel_size

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        requests_in_progress,
                        return_when=asyncio.FIRST_COMPLETED)
                    # Yield to the event loop once per virtual engine
                    # before inspecting the finished tasks.
                    for _ in range(pipeline_parallel_size):
                        await asyncio.sleep(0)
                for task in done:
                    # ``result()`` re-raises any exception from the step.
                    result = task.result()
                    # The position of the task in the list is its virtual
                    # engine index.
                    virtual_engine = requests_in_progress.index(task)
                    if self.engine_use_ray:
                        has_unfinished_requests = (
                            await (self.engine.
                                   has_unfinished_requests_for_virtual_engine.
                                   remote(  # type: ignore
                                       virtual_engine)))
                    else:
                        has_unfinished_requests = (
                            self.engine.
                            has_unfinished_requests_for_virtual_engine(
                                virtual_engine))
                    if result or has_unfinished_requests:
                        # More work for this virtual engine: schedule its
                        # next step in the same slot.
                        requests_in_progress[virtual_engine] = (
                            asyncio.create_task(
                                self.engine_step(virtual_engine)))
                        has_requests_in_progress[virtual_engine] = True
                    else:
                        has_requests_in_progress[virtual_engine] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise
            await asyncio.sleep(0)

796
797
    # This method does not need to be async, but kept that way
    # for backwards compatibility.
    async def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Register a request with the request tracker.

        Args:
            request_id: The unique id of the request.
            inputs: The inputs to the LLM.
            params: Sampling or pooling parameters of the request.
            arrival_time: Arrival timestamp of the request. Only ``None``
                is replaced by the current ``time.time()``; an explicit
                value (including ``0.0``) is passed through unchanged.
            lora_request: LoRA request to use, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            The generator obtained from the stream that the request
            tracker created for this request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                ``start_engine_loop`` is false.
        """
        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        # Compare against None rather than relying on truthiness so that a
        # caller-supplied timestamp of 0.0 is not silently overwritten.
        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            inputs=inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request)

        return stream.generator()
829

830
    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        This method is an asynchronous generator. It adds the request into
        the waiting queue of the LLMEngine and streams the outputs from the
        LLMEngine to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use
                                            for generation, if any.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in await self.add_request(
                request_id,
                inputs,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
        ):
            # Every streamed item is passed through validate_output with
            # the output type this entry point is annotated to yield.
            yield LLMEngine.validate_output(output, RequestOutput)
912
913
914

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model.

        This method is an asynchronous generator. It adds the request into
        the waiting queue of the LLMEngine and streams the outputs from the
        LLMEngine to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.

        Yields:
            The output `EmbeddingRequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in await self.add_request(
                request_id,
                inputs,
                pooling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
        ):
            # Every streamed item is passed through validate_output with
            # the output type this entry point is annotated to yield.
            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
989

Antoni Baum's avatar
Antoni Baum committed
990
991
    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if not self.is_running:
            raise AsyncEngineDeadError(
                "Background loop is not running. If it was running, "
                "inspect the output to find the stacktrace of the "
                "error that caused the background loop to stop "
                "(AsyncEngineDeadError).")

        return self._abort(request_id)
1007

Antoni Baum's avatar
Antoni Baum committed
1008
    def _abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Unlike :meth:`abort`, this does not check whether the background
        loop is running; it only marks the request as cancelled in the
        request tracker.

        Args:
            request_id: The unique id of the request.
        """
        self._request_tracker.abort_request(request_id,
                                            cancelled=True,
                                            verbose=self.log_requests)
1020

1021
1022
1023
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine.

        Returns:
            Whatever the engine's ``get_model_config`` returns, awaited
            through Ray when the engine runs as a Ray actor.
        """
        if self.engine_use_ray:
            return await self.engine.get_model_config.remote()  # type: ignore
        return self.engine.get_model_config()

1028
1029
1030
1031
1032
1033
1034
1035
    async def get_parallel_config(self) -> ParallelConfig:
        """Get the parallel configuration of the vLLM engine.

        Returns:
            Whatever the engine's ``get_parallel_config`` returns, awaited
            through Ray when the engine runs as a Ray actor.
        """
        if self.engine_use_ray:
            return await self.engine.get_parallel_config.remote(  # type: ignore
            )
        return self.engine.get_parallel_config()

1036
1037
1038
1039
1040
1041
1042
1043
    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine.

        Returns:
            Whatever the engine's ``get_decoding_config`` returns, awaited
            through Ray when the engine runs as a Ray actor.
        """
        if self.engine_use_ray:
            return await self.engine.get_decoding_config.remote(  # type: ignore
            )
        return self.engine.get_decoding_config()

1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
    async def get_scheduler_config(self) -> SchedulerConfig:
        """Get the scheduling configuration of the vLLM engine.

        Returns:
            Whatever the engine's ``get_scheduler_config`` returns, awaited
            through Ray when the engine runs as a Ray actor.
        """
        if self.engine_use_ray:
            return await self.engine.get_scheduler_config.remote(  # type: ignore
            )
        return self.engine.get_scheduler_config()

    async def get_lora_config(self) -> LoRAConfig:
        """Get the lora configuration of the vLLM engine.

        Returns:
            Whatever the engine's ``get_lora_config`` returns, awaited
            through Ray when the engine runs as a Ray actor.
        """
        if self.engine_use_ray:
            return await self.engine.get_lora_config.remote(  # type: ignore
            )
        return self.engine.get_lora_config()

1060
1061
1062
1063
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the engine to log its statistics.

        Both arguments are forwarded to the engine's ``do_log_stats`` on
        either path, so an in-process engine receives the same information
        as a Ray-hosted one.

        Args:
            scheduler_outputs: Optional scheduler outputs to pass to the
                engine's stats logging.
            model_output: Optional sampler outputs to pass to the engine's
                stats logging.
        """
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote(  # type: ignore
                scheduler_outputs, model_output)
        else:
            self.engine.do_log_stats(scheduler_outputs, model_output)
1069

1070
    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy.

        Raises:
            AsyncEngineDeadError: If the background loop is stopped.
            RuntimeError: If the engine runs as a Ray actor and the health
                check fails with ``ray.exceptions.RayActorError``.
        """
        t = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        if self.engine_use_ray:
            try:
                await self.engine.check_health.remote()  # type: ignore
            except ray.exceptions.RayActorError as e:
                raise RuntimeError("Engine is dead.") from e
        else:
            await self.engine.check_health_async()
        logger.debug("Health check took %fs", time.perf_counter() - t)
1085
1086
1087
1088
1089
1090
1091

    async def is_tracing_enabled(self) -> bool:
        """Report whether tracing is enabled on the underlying engine.

        Returns:
            Whatever the engine's ``is_tracing_enabled`` returns, awaited
            through Ray when the engine runs as a Ray actor.
        """
        if self.engine_use_ray:
            return await self.engine.is_tracing_enabled.remote(  # type: ignore
            )
        return self.engine.is_tracing_enabled()
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger on the underlying engine.

        On the Ray path this blocks on ``ray.get`` until the remote call
        has completed.

        Args:
            logger_name: Name under which the logger is registered.
            logger: The stat logger to register.
        """
        if self.engine_use_ray:
            ray.get(
                self.engine.add_logger.remote(  # type: ignore
                    logger_name=logger_name, logger=logger))
        else:
            self.engine.add_logger(logger_name=logger_name, logger=logger)

    def remove_logger(self, logger_name: str) -> None:
        """Remove a stat logger from the underlying engine.

        On the Ray path this blocks on ``ray.get`` until the remote call
        has completed.

        Args:
            logger_name: Name of the logger to remove.
        """
        if self.engine_use_ray:
            ray.get(
                self.engine.remove_logger.remote(  # type: ignore
                    logger_name=logger_name))
        else:
            self.engine.remove_logger(logger_name=logger_name)