async_llm_engine.py 36.9 KB
Newer Older
1
2
import asyncio
import time
Antoni Baum's avatar
Antoni Baum committed
3
from functools import partial
4
5
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
                    Set, Tuple, Type, Union)
6

7
8
from transformers import PreTrainedTokenizer

9
import vllm.envs as envs
10
from vllm.config import DecodingConfig, ModelConfig
11
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
12
from vllm.engine.arg_utils import AsyncEngineArgs
13
from vllm.engine.async_timeout import asyncio_timeout
Woosuk Kwon's avatar
Woosuk Kwon committed
14
from vllm.engine.llm_engine import LLMEngine
15
from vllm.executor.ray_utils import initialize_ray_cluster, ray
16
from vllm.inputs import LLMInputs, PromptInputs
Woosuk Kwon's avatar
Woosuk Kwon committed
17
from vllm.logger import init_logger
18
from vllm.lora.request import LoRARequest
19
20
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
21
from vllm.sampling_params import SamplingParams
22
from vllm.sequence import ExecuteModelRequest, SamplerOutput
yhu422's avatar
yhu422 committed
23
from vllm.usage.usage_lib import UsageContext
24
25

# Module-level logger for this file.
logger = init_logger(__name__)
# Upper bound, in seconds, on waiting for an engine iteration inside
# AsyncLLMEngine.run_engine_loop; exceeding it marks the engine as errored.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
27

Antoni Baum's avatar
Antoni Baum committed
28

29
30
31
32
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine's background loop has failed, has errored
    previously, or is not running and cannot be started automatically."""


33
34
35
36
37
38
39
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """This function is only intended for the `engine.run_engine_loop()` task.

    In particular, that task runs a `while True` loop that can only exit if
    there is an exception. Any completion other than cancellation is
    therefore treated as a fatal engine error.

    Args:
        task: The finished background task.
        error_callback: Called with the exception that ended the task. If the
            task returned normally, it is called with the AssertionError
            raised here instead.

    Raises:
        AsyncEngineDeadError: If the task finished for any reason other than
            cancellation. It is chained to the original exception.
    """
    try:
        return_value = task.result()
        # Reaching this point means the loop returned, which it never should;
        # turn that into an error so it is handled like any other failure.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {return_value}")
    except asyncio.exceptions.CancelledError:
        # We assume that if the task is cancelled, we are gracefully shutting
        # down. This should only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from e
59
60


Antoni Baum's avatar
Antoni Baum committed
61
class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously.

    Producers push items with :meth:`put`, and consumers read them with
    ``async for``. A queued exception is raised to the consumer.
    :meth:`finish` queues a ``StopAsyncIteration`` instance, which ends the
    iteration once it is reached.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Queue ``item`` for the consumer; ignored once finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(self) -> None:
        """Queue the end-of-stream marker and mark the stream finished."""
        self._queue.put_nowait(StopAsyncIteration())
        self._finished = True

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has been called."""
        return self._finished

    def __aiter__(self):
        return self

    async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
        item = await self._queue.get()
        if not isinstance(item, Exception):
            return item
        # Both propagated errors and the StopAsyncIteration marker queued by
        # finish() are raised here; the latter terminates `async for`.
        raise item


94
95
96
97
98
99
100
101
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    :meth:`add_request` queues new requests, and :meth:`abort_request` queues
    the ids of finished or aborted requests. The engine loop drains both
    queues via :meth:`get_new_and_finished_requests`. Streams of requests
    that have been handed to the engine are kept in ``_request_streams``,
    keyed by request id.
    """

    def __init__(self) -> None:
        self._request_streams: Dict[str, AsyncStream] = {}
        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set by add_request(); awaited and cleared by wait_for_new_requests().
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None)."""
        if request_id is not None:
            self._request_streams[request_id].put(exc)
            self.abort_request(request_id)
        else:
            for rid, stream in self._request_streams.items():
                stream.put(exc)
                self.abort_request(rid)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        The output is delivered to the request's stream. A finished output
        also aborts the request, which closes its stream and schedules its
        removal.

        Outputs for requests whose stream is no longer tracked are dropped.
        Several ``engine_step`` tasks (one per virtual engine) may run
        concurrently, so a request can be aborted and its stream removed by
        one of them while another is still generating an output for it.
        """
        request_id = request_output.request_id
        stream = self._request_streams.get(request_id)
        if stream is None:
            # Already aborted and removed; nobody is listening any more.
            return

        stream.put(request_output)
        if request_output.finished:
            if verbose:
                logger.info("Finished request %s.", request_id)
            self.abort_request(request_id)

    def process_exception(self,
                          request_id: str,
                          exception: Exception,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine.

        The exception is delivered to the request's stream if that stream is
        still tracked. It may already have been removed if the request was
        aborted while it was being added to the engine. The request is then
        aborted.
        """
        stream = self._request_streams.get(request_id)
        if stream is not None:
            stream.put(exception)
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id)

    def add_request(self, request_id: str,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Raises:
            KeyError: If a stream for ``request_id`` is already tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        self.new_requests_event.set()

        return stream

    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
        """Abort a request during next background loop iteration."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        # Always queued so the engine-side abort happens on the next step,
        # even for requests whose stream is not (yet) tracked here.
        self._finished_requests.put_nowait(request_id)

        if request_id not in self._request_streams or self._request_streams[
                request_id].finished:
            # The request has already finished or been aborted.
            return

        self._request_streams[request_id].finish()

    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        Finished request ids are drained first and their streams dropped.
        Queued new requests whose id is among them are finished immediately
        instead of being returned.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._finished_requests.empty():
            request_id = self._finished_requests.get_nowait()
            finished_requests.add(request_id)
            self._request_streams.pop(request_id, None)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            if stream.request_id in finished_requests:
                # The request has already been aborted.
                stream.finish()
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Block until a new request is queued, then clear the event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Whether any new request is waiting to be handed to the engine."""
        return not self._new_requests.empty()
208

Antoni Baum's avatar
Antoni Baum committed
209
210
211
212

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods.

    It provides coroutine versions of:

    - stepping (:meth:`step_async`);
    - input processing (:meth:`process_model_inputs_async`);
    - request submission (:meth:`add_request_async`);
    - health checking;
    - stopping the remote worker execution loop.
    """

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.

        The model executor is awaited asynchronously. The iteration proceeds
        as follows:

        1. Ask the scheduler of ``virtual_engine`` which sequence groups run
           next and which blocks to swap in, swap out or copy.
        2. If anything was scheduled, execute the model on that batch.
        3. Process the model outputs (an empty list when nothing ran) into
           request outputs.
        4. Record stats and tracing for the iteration.
        """
        scheduler = self.scheduler[virtual_engine]
        seq_group_metadata_list, scheduler_outputs = scheduler.schedule()

        if scheduler_outputs.is_empty():
            # Nothing was scheduled; skip model execution entirely.
            output = []
        else:
            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
            )
            output = await self.model_executor.execute_model_async(
                execute_model_req)

        request_outputs = self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

        # Stats and tracing are recorded on every iteration, even when
        # nothing was scheduled.
        self.do_log_stats(scheduler_outputs, output)
        self.do_tracing(scheduler_outputs)

        return request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Ask the model executor to stop its remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def process_model_inputs_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        lora_request: Optional[LoRARequest] = None,
    ) -> LLMInputs:
        """Turn raw prompt ``inputs`` into processed :class:`LLMInputs`.

        A bare string is treated as ``{"prompt": inputs}``. If the inputs do
        not already contain ``prompt_token_ids``, the prompt is tokenized
        asynchronously with the engine's tokenizer group, passing
        ``lora_request``. The assembled inputs are passed through
        ``self.input_processor``, and its result is returned.
        """
        prompt_inputs = ({"prompt": inputs}
                         if isinstance(inputs, str) else inputs)

        if "prompt_token_ids" in prompt_inputs:
            prompt_token_ids = prompt_inputs["prompt_token_ids"]
        else:
            tokenizer = self.get_tokenizer_group("prompts must be None if "
                                                 "skip_tokenizer_init is True")
            prompt_token_ids = await tokenizer.encode_async(
                request_id=request_id,
                prompt=prompt_inputs["prompt"],
                lora_request=lora_request)

        return self.input_processor(
            LLMInputs(prompt_token_ids=prompt_token_ids,
                      prompt=prompt_inputs.get("prompt"),
                      multi_modal_data=prompt_inputs.get("multi_modal_data")))

    async def add_request_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> None:
        """Process ``inputs`` asynchronously and add the request.

        ``arrival_time`` defaults to the current time.

        Raises:
            ValueError: If a LoRA request is given while LoRA is not enabled.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        arrival_time = time.time() if arrival_time is None else arrival_time

        processed_inputs = await self.process_model_inputs_async(
            request_id=request_id, inputs=inputs, lora_request=lora_request)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )

    async def check_health_async(self) -> None:
        """Run the tokenizer's health check (if a tokenizer is set), then the
        model executor's."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
317

318

319
class AsyncLLMEngine:
    """An asynchronous wrapper for :class:`LLMEngine`.

    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
            async frontend will be executed in a separate process from the
            model workers.
        log_requests: Whether to log the requests.
        max_log_len: Maximum number of prompt characters or prompt token IDs
            printed in the log.
        start_engine_loop: If True, the background task to run the engine
            will be started automatically when a request is added while the
            loop is not running.
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
    """
343

Antoni Baum's avatar
Antoni Baum committed
344
345
    # Engine implementation that _init_engine instantiates, either in this
    # process or wrapped as a Ray actor.
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

346
347
348
349
350
    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 max_log_len: Optional[int] = None,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Store the frontend options and construct the wrapped engine.

        The remaining ``*args`` / ``**kwargs`` are forwarded to
        :meth:`_init_engine`. The background loop is not started here.
        """
        # _init_engine consults the Ray flags, so they are assigned first.
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
        self.max_log_len = max_log_len
        self.engine = self._init_engine(*args, **kwargs)

        # Both loop handles stay None until start_background_loop() runs.
        # The unshielded task is referenced too, so that it cannot be
        # garbage collected while the shielded future is in use.
        self.background_loop: Optional[asyncio.Future] = None
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Lazily initialized by start_background_loop(), so that it is
        # created on the right event loop.
        self._request_tracker: RequestTracker

371
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        The executor class is chosen from the configured device type and,
        where relevant, the distributed executor backend. When the Ray
        backend is selected on GPU or XPU, the Ray cluster is initialized
        first. ``worker_use_ray`` is True exactly when that backend is
        ``"ray"``.
        """
        # Create the engine configs.
        engine_config = engine_args.create_engine_config()
        backend = engine_config.parallel_config.distributed_executor_backend

        def _pick_executor_class():
            device_type = engine_config.device_config.device_type
            if device_type == "neuron":
                from vllm.executor.neuron_executor import NeuronExecutorAsync
                return NeuronExecutorAsync
            if device_type == "tpu":
                from vllm.executor.tpu_executor import TPUExecutorAsync
                return TPUExecutorAsync
            if device_type == "cpu":
                assert backend is None, (
                    "Distributed execution is not supported with "
                    "the CPU backend.")
                from vllm.executor.cpu_executor import CPUExecutorAsync
                return CPUExecutorAsync
            if device_type == "openvino":
                assert backend is None, (
                    "Distributed execution is not supported with "
                    "the OpenVINO backend.")
                from vllm.executor.openvino_executor import (
                    OpenVINOExecutorAsync)
                return OpenVINOExecutorAsync
            if device_type == "xpu":
                if backend is None:
                    from vllm.executor.xpu_executor import XPUExecutorAsync
                    return XPUExecutorAsync
                if backend == "ray":
                    initialize_ray_cluster(engine_config.parallel_config)
                    from vllm.executor.ray_xpu_executor import (
                        RayXPUExecutorAsync)
                    return RayXPUExecutorAsync
                raise RuntimeError(
                    "Not supported distributed execution model on XPU device.")
            # All remaining device types use the GPU executors.
            if backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
                return RayGPUExecutorAsync
            if backend == "mp":
                from vllm.executor.multiproc_gpu_executor import (
                    MultiprocessingGPUExecutorAsync)
                return MultiprocessingGPUExecutorAsync
            from vllm.executor.gpu_executor import GPUExecutorAsync
            return GPUExecutorAsync

        executor_class = _pick_executor_class()

        # Create the async LLM engine.
        return cls(
            backend == "ray",
            engine_args.engine_use_ray,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            max_log_len=engine_args.max_log_len,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
        )

437
438
    @property
    def is_running(self) -> bool:
        """True while the background loop has been started and its task has
        not finished."""
        loop_task = self._background_loop_unshielded
        return (self.background_loop is not None and loop_task is not None
                and not loop_task.done())

    @property
    def is_stopped(self) -> bool:
        """True if the engine has errored, or its background loop was
        started and its task has finished."""
        if self.errored:
            return True
        loop_task = self._background_loop_unshielded
        return (self.background_loop is not None and loop_task is not None
                and loop_task.done())

    @property
    def errored(self) -> bool:
        """Whether an error has been recorded via :meth:`set_errored`."""
        return self._errored_with is not None

    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` as the engine's fatal error.

        After this call, :attr:`errored` is True, and
        :meth:`start_background_loop` raises instead of starting.
        """
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Handle a failure of the background loop task.

        Marks the engine as errored and pushes ``exc`` to every tracked
        request stream. It is passed to ``_log_task_completion`` as its
        ``error_callback``.
        """
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)
459

460
461
    async def get_tokenizer(self) -> "PreTrainedTokenizer":
        """Return the engine's tokenizer.

        When the engine is a Ray actor, the tokenizer is fetched through a
        remote call.
        """
        if not self.engine_use_ray:
            return self.engine.get_tokenizer()
        return await self.engine.get_tokenizer.remote()  # type: ignore
465

466
    def start_background_loop(self) -> None:
        """Start the background loop.

        A fresh RequestTracker is created, ``run_engine_loop`` is scheduled
        as a task, and ``_log_task_completion`` is attached as its
        done-callback. ``background_loop`` exposes a shielded view of that
        task.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the background loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        event_loop = asyncio.get_event_loop()
        loop_task = event_loop.create_task(self.run_engine_loop())
        self._background_loop_unshielded = loop_task
        loop_task.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        self.background_loop = asyncio.shield(loop_task)
Antoni Baum's avatar
Antoni Baum committed
481
482
483

    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
        """Construct the wrapped engine from ``*args`` / ``**kwargs``.

        Without ``engine_use_ray``, ``_engine_class`` is instantiated in this
        process. Otherwise it is created as a Ray actor:

        - with ``num_cpus=0`` when the workers also use Ray;
        - otherwise with ``num_gpus`` set to the cache config's
          ``gpu_memory_utilization`` when tensor and pipeline parallel sizes
          are both 1, and to 1 in all other cases.
        """
        if not self.engine_use_ray:
            return self._engine_class(*args, **kwargs)

        if self.worker_use_ray:
            remote_options = {"num_cpus": 0}
        else:
            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
            # order of the arguments.
            cache_config = kwargs["cache_config"]
            parallel_config = kwargs["parallel_config"]
            no_model_parallelism = (
                parallel_config.tensor_parallel_size == 1
                and parallel_config.pipeline_parallel_size == 1)
            remote_options = {
                "num_gpus": (cache_config.gpu_memory_utilization
                             if no_model_parallelism else 1)
            }
        actor_class = ray.remote(**remote_options)(self._engine_class)
        return actor_class.remote(*args, **kwargs)

502
    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        The step does the following, in order:

        1. Hand every newly queued request to the engine. A ``ValueError``
           raised while adding a request is reported on that request's
           stream.
        2. Abort requests that finished or were aborted since the last step.
        3. Run one engine step: ``step_async(virtual_engine)`` locally, or the
           remote ``step`` when the engine is a Ray actor.
        4. Route the resulting outputs to their streams.

        Returns:
            True if the step produced at least one request output.
        """
        new_requests, finished_requests = (
            self._request_tracker.get_new_and_finished_requests())

        for request_kwargs in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            # TODO: Maybe add add_request_batch to reduce Ray overhead
            try:
                if self.engine_use_ray:
                    pending_add = self.engine.add_request.remote(  # type: ignore
                        **request_kwargs)
                else:
                    pending_add = self.engine.add_request_async(
                        **request_kwargs)
                await pending_add
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    request_kwargs["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if finished_requests:
            await self._engine_abort(finished_requests)

        if self.engine_use_ray:
            pending_step = self.engine.step.remote()  # type: ignore
        else:
            pending_step = self.engine.step_async(virtual_engine)
        request_outputs = await pending_step

        # Put the outputs into the corresponding streams.
        for output in request_outputs:
            self._request_tracker.process_request_output(
                output, verbose=self.log_requests)

        return len(request_outputs) != 0

Antoni Baum's avatar
Antoni Baum committed
542
543
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Abort ``request_ids`` in the engine.

        When the engine is a Ray actor, the abort is performed through a
        remote call.
        """
        if not self.engine_use_ray:
            self.engine.abort_request(request_ids)
            return
        await self.engine.abort_request.remote(request_ids)  # type: ignore

    async def run_engine_loop(self):
        """Drive the engine with one ``engine_step`` task per virtual engine.

        Runs forever and handles two situations:

        - No virtual engine has work in flight: the remote worker execution
          loop is stopped, the coroutine waits until the request tracker
          reports new requests, and a fresh step task is created for every
          virtual engine.
        - Work is in flight: the loop waits for the first step task(s) to
          complete. For each completed virtual engine that returned outputs
          or still has unfinished requests, a new step task is created.

        Raises:
            asyncio.TimeoutError: If an iteration exceeds
                ``ENGINE_ITERATION_TIMEOUT_S``. The engine is marked as
                errored before re-raising.
        """
        if self.engine_use_ray:
            num_virtual_engines = 1  # type: ignore
        else:
            num_virtual_engines = (
                self.engine.parallel_config.pipeline_parallel_size)
        busy = [False] * num_virtual_engines
        while True:
            if not any(busy):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                if self.engine_use_ray:
                    await (self.engine.stop_remote_worker_execution_loop.
                           remote()  # type: ignore
                           )
                else:
                    await self.engine.stop_remote_worker_execution_loop_async()
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")
                step_tasks = [
                    asyncio.create_task(self.engine_step(ve))
                    for ve in range(num_virtual_engines)
                ]
                busy = [True] * num_virtual_engines

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        step_tasks, return_when=asyncio.FIRST_COMPLETED)
                    # Yield to the event loop once per virtual engine before
                    # collecting the results.
                    for _ in range(num_virtual_engines):
                        await asyncio.sleep(0)
                for finished_step in done:
                    produced_outputs = finished_step.result()
                    ve = step_tasks.index(finished_step)
                    if self.engine_use_ray:
                        has_unfinished = await (
                            self.engine.
                            has_unfinished_requests_for_virtual_engine.
                            remote(ve))  # type: ignore
                    else:
                        has_unfinished = (
                            self.engine.
                            has_unfinished_requests_for_virtual_engine(ve))
                    keep_stepping = bool(produced_outputs or has_unfinished)
                    if keep_stepping:
                        step_tasks[ve] = asyncio.create_task(
                            self.engine_step(ve))
                    busy[ve] = keep_stepping
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise
            await asyncio.sleep(0)

    async def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> AsyncStream:
        """Register a request with the request tracker and return its stream.

        The method proceeds as follows:

        1. When ``log_requests`` is set, log the request, with the prompt and
           prompt token ids truncated to ``max_log_len``.
        2. If the background loop is not running, start it, or raise when
           automatic start is disabled.
        3. Default ``arrival_time`` to the current time.

        Raises:
            AsyncEngineDeadError: If the loop is not running and
                ``start_engine_loop`` is False, or if starting it fails
                because the engine has errored.
        """
        if self.log_requests:
            if isinstance(inputs, str):
                logged_prompt, logged_token_ids = inputs, None
            else:
                logged_prompt = inputs.get("prompt")
                logged_token_ids = inputs.get("prompt_token_ids")

            limit = self.max_log_len
            if limit is not None:
                if logged_prompt is not None:
                    logged_prompt = logged_prompt[:limit]
                if logged_token_ids is not None:
                    logged_token_ids = logged_token_ids[:limit]

            logger.info(
                "Received request %s: prompt: %r, "
                "params: %s, prompt_token_ids: %s, "
                "lora_request: %s.", request_id, logged_prompt, params,
                logged_token_ids, lora_request)

        if not self.is_running:
            if not self.start_engine_loop:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")
            self.start_background_loop()

        return self._request_tracker.add_request(
            request_id,
            inputs=inputs,
            params=params,
            arrival_time=time.time() if arrival_time is None else arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )
668

669
    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> AsyncIterator[RequestOutput]:
        """Stream generation outputs for a single request.

        This is an async generator: it registers the request with the engine
        (through :meth:`add_request`) and yields every output produced for
        it until the request's stream is exhausted.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.

        Yields:
            The `RequestOutput` objects produced by the LLMEngine for this
            request, each checked with ``LLMEngine.validate_output``.

        Details:
            - If the background loop is not running, it is started (or, when
              ``start_engine_loop`` is disabled, ``AsyncEngineDeadError`` is
              raised). The loop iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - The request is added to the engine's `RequestTracker`, which
              hands back an `AsyncStream`; on the next background loop the
              request is sent to the underlying engine.
            - Outputs are read from that `AsyncStream` and yielded. If
              iteration raises or is cancelled, the request is aborted.

        Example:
            >>> # See entrypoints/api_server.py for a complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> request_id = "0"
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    request_id)
            >>>
            >>> # collect the results; `request` stands for the client
            >>> # connection object of the web framework in use
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(request_id)
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        raw_outputs = self._process_request(
            request_id,
            inputs,
            sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )
        async for raw_output in raw_outputs:
            yield LLMEngine.validate_output(raw_output, RequestOutput)
747
748
749

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> AsyncIterator[EmbeddingRequestOutput]:
        """Stream embedding outputs for a single request.

        This is an async generator: it registers the request with the engine
        (through :meth:`add_request`) and yields every output produced for
        it by an embedding model until the request's stream is exhausted.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.

        Yields:
            The `EmbeddingRequestOutput` objects produced by the LLMEngine
            for this request, each checked with ``LLMEngine.validate_output``.

        Details:
            - If the background loop is not running, it is started (or, when
              ``start_engine_loop`` is disabled, ``AsyncEngineDeadError`` is
              raised). The loop iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - The request is added to the engine's `RequestTracker`, which
              hands back an `AsyncStream`; on the next background loop the
              request is sent to the underlying engine.
            - Outputs are read from that `AsyncStream` and yielded. If
              iteration raises or is cancelled, the request is aborted.

        Example:
            >>> # See entrypoints/api_server.py for a complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> request_id = "0"
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    request_id)
            >>>
            >>> # collect the results; `request` stands for the client
            >>> # connection object of the web framework in use
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(request_id)
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        raw_outputs = self._process_request(
            request_id,
            inputs,
            pooling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )
        async for raw_output in raw_outputs:
            yield LLMEngine.validate_output(raw_output, EmbeddingRequestOutput)
824

825
    async def _process_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        *,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Common logic to process requests with SamplingParams or
        PoolingParams.

        Registers the request via :meth:`add_request` (stamping the current
        wall-clock time as its arrival time) and re-yields every output from
        the returned stream.

        If iteration raises, is cancelled, or the consumer closes this
        generator before the stream is exhausted, the request is aborted
        through :meth:`_abort` and the original exception is re-raised.
        """
        arrival_time = time.time()

        stream = await self.add_request(
            request_id,
            inputs,
            params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )

        try:
            async for request_output in stream:
                yield request_output
        except (Exception, asyncio.CancelledError, GeneratorExit):
            # GeneratorExit is raised at the ``yield`` when the consumer stops
            # iterating early (``aclose()`` or async-generator finalization);
            # it derives from BaseException, so ``Exception`` alone misses it.
            # ``_abort`` is synchronous, so calling it here is safe even while
            # the generator is being closed.
            self._abort(request_id)
            # Bare ``raise`` re-raises with the original traceback untouched.
            raise
853

Antoni Baum's avatar
Antoni Baum committed
854
855
    async def abort(self, request_id: str) -> None:
        """Abort a submitted request.

        Finished or unknown requests are a no-op. The actual work is
        delegated to :meth:`_abort`.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if self.is_running:
            return self._abort(request_id)

        raise AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")
871

Antoni Baum's avatar
Antoni Baum committed
872
    def _abort(self, request_id: str) -> None:
        """Abort a submitted request without checking the loop state.

        Forwards the id to the request tracker's ``abort_request``. That call
        is a no-op when the request is finished or not found. Logging is
        verbose whenever ``self.log_requests`` is set.

        Args:
            request_id: The unique id of the request.
        """
        tracker = self._request_tracker
        tracker.abort_request(request_id, verbose=self.log_requests)
883

884
885
886
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine.

        The call goes to the Ray actor when ``engine_use_ray`` is set, and
        to the local engine otherwise.
        """
        if not self.engine_use_ray:
            return self.engine.get_model_config()
        return await self.engine.get_model_config.remote()  # type: ignore

891
892
893
894
895
896
897
898
    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine.

        The call goes to the Ray actor when ``engine_use_ray`` is set, and
        to the local engine otherwise.
        """
        if not self.engine_use_ray:
            return self.engine.get_decoding_config()
        return await self.engine.get_decoding_config.remote()  # type: ignore

899
900
901
902
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the underlying engine to log its stats.

        Both arguments are forwarded unchanged to the engine's
        ``do_log_stats``, whether the engine is a Ray actor or local.

        Args:
            scheduler_outputs: Optional scheduler outputs to include.
            model_output: Optional sampler outputs to include.
        """
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote(  # type: ignore
                scheduler_outputs, model_output)
        else:
            # Forward the same arguments as the Ray path so that both
            # execution modes log from identical inputs.
            self.engine.do_log_stats(scheduler_outputs, model_output)
908

909
    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy.

        Raises:
            AsyncEngineDeadError: If the background loop is stopped.
            RuntimeError: If the Ray engine actor has died (chained from the
                ``RayActorError``).

        Without Ray, the check is delegated to the engine's
        ``check_health_async``. The elapsed time is logged at debug level.
        """
        started_at = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        if not self.engine_use_ray:
            await self.engine.check_health_async()
        else:
            try:
                await self.engine.check_health.remote()  # type: ignore
            except ray.exceptions.RayActorError as actor_error:
                raise RuntimeError("Engine is dead.") from actor_error

        elapsed = time.perf_counter() - started_at
        logger.debug("Health check took %fs", elapsed)
924
925
926
927
928
929
930

    async def is_tracing_enabled(self) -> bool:
        """Report whether tracing is enabled on the underlying engine.

        The call goes to the Ray actor when ``engine_use_ray`` is set, and
        to the local engine otherwise.
        """
        if not self.engine_use_ray:
            return self.engine.is_tracing_enabled()
        return await self.engine.is_tracing_enabled.remote()  # type: ignore