async_llm_engine.py 38.3 KB
Newer Older
1
2
import asyncio
import time
Antoni Baum's avatar
Antoni Baum committed
3
from functools import partial
4
5
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
                    Set, Tuple, Type, Union)
6

7
8
from transformers import PreTrainedTokenizer

9
import vllm.envs as envs
10
from vllm.config import DecodingConfig, ModelConfig
11
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
12
from vllm.engine.arg_utils import AsyncEngineArgs
13
from vllm.engine.async_timeout import asyncio_timeout
Woosuk Kwon's avatar
Woosuk Kwon committed
14
from vllm.engine.llm_engine import LLMEngine
15
from vllm.executor.ray_utils import initialize_ray_cluster, ray
16
from vllm.inputs import LLMInputs, PromptInputs
Woosuk Kwon's avatar
Woosuk Kwon committed
17
from vllm.logger import init_logger
18
from vllm.lora.request import LoRARequest
19
20
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
21
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
22
from vllm.sampling_params import SamplingParams
23
from vllm.sequence import ExecuteModelRequest, SamplerOutput
yhu422's avatar
yhu422 committed
24
from vllm.usage.usage_lib import UsageContext
25
26

logger = init_logger(__name__)
# Watchdog for one wait on the engine-step tasks in
# AsyncLLMEngine.run_engine_loop; value comes from vllm.envs (presumably
# seconds, per the ``_S`` suffix).
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
28

Antoni Baum's avatar
Antoni Baum committed
29

30
31
32
33
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine's background loop has failed or is not
    running, so no further requests can be served."""


34
35
36
37
38
39
40
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback for the `engine.run_engine_loop()` background task.

    That task runs a `while True` loop that can only exit if there is an
    exception (or if it is cancelled), so every completion is inspected here:

    * Cancellation is treated as a graceful shutdown and only logged.
    * Any other outcome -- including the task returning normally, which is
      turned into an ``AssertionError`` below -- is logged, reported through
      ``error_callback`` and re-raised as :class:`AsyncEngineDeadError`.

    Args:
        task: The completed background-loop task.
        error_callback: Called with the exception that ended the task so the
            owner can record the failure and notify pending requests.

    Raises:
        AsyncEngineDeadError: If the task ended for any reason other than
            cancellation.
    """
    try:
        return_value = task.result()
        # Raised inside the ``try`` on purpose: a normal return is then
        # handled by the ``except Exception`` branch like any other failure.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {return_value}")
    except asyncio.exceptions.CancelledError:
        # We assume that if the task is cancelled, we are gracefully shutting
        # down. This should only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from e
60
61


Antoni Baum's avatar
Antoni Baum committed
62
class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously.

    Items are buffered in an unbounded :class:`asyncio.Queue`. Exceptions put
    on the stream are re-raised to the consumer by ``__anext__``; ``finish()``
    enqueues a ``StopAsyncIteration`` sentinel that ends ``async for`` loops.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._queue: asyncio.Queue = asyncio.Queue()
        # Set by finish(); afterwards put() drops items.
        self._finished = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Enqueue an output, or an exception to raise to the consumer.

        Items put after :meth:`finish` has been called are silently dropped.
        """
        if self._finished:
            return
        self._queue.put_nowait(item)

    def finish(self) -> None:
        """Enqueue the end-of-stream sentinel and mark the stream finished."""
        self._queue.put_nowait(StopAsyncIteration())
        self._finished = True

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has been called."""
        return self._finished

    def __aiter__(self):
        return self

    async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
        """Wait for the next item; raise it if it is an exception.

        The ``StopAsyncIteration`` sentinel from :meth:`finish` is an
        ``Exception`` subclass, so raising it here ends iteration.
        """
        result = await self._queue.get()
        if isinstance(result, Exception):
            raise result
        return result


95
96
97
98
99
100
101
102
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    Requests are registered with :meth:`add_request` and cancelled with
    :meth:`abort_request`; the engine loop drains both queues via
    :meth:`get_new_and_finished_requests` and routes engine outputs back to
    the per-request :class:`AsyncStream` objects.
    """

    def __init__(self) -> None:
        # Streams of requests that have been handed to the engine.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Ids of requests finished or aborted since the last drain.
        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, engine add_request kwargs) pairs awaiting the next drain.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set whenever a request is queued; see wait_for_new_requests().
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        """Whether ``item`` is the id of a tracked (drained) request."""
        return item in self._request_streams

    def __len__(self) -> int:
        """Number of tracked (drained, not yet finished) requests."""
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None).

        Each affected stream receives ``exc`` and is then aborted.
        """
        if request_id is not None:
            self._request_streams[request_id].put(exc)
            self.abort_request(request_id)
        else:
            for rid, stream in self._request_streams.items():
                stream.put(exc)
                self.abort_request(rid)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        The output is delivered to the request's stream; finished requests
        are then aborted so the stream is closed and the id is reported to
        the engine loop. An output whose request is no longer tracked (it
        was aborted and drained while the engine produced this output, e.g.
        by another virtual engine's step) is dropped instead of raising a
        ``KeyError`` that would escape ``engine_step``.
        """
        request_id = request_output.request_id

        stream = self._request_streams.get(request_id)
        if stream is not None:
            stream.put(request_output)
        if request_output.finished:
            if verbose:
                logger.info("Finished request %s.", request_id)
            self.abort_request(request_id)

    def process_exception(self,
                          request_id: str,
                          exception: Exception,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine.

        The exception is delivered to the request's stream (if it is still
        tracked) and the request is aborted.
        """
        stream = self._request_streams.get(request_id)
        if stream is not None:
            stream.put(exception)
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id)

    def add_request(self, request_id: str,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Raises:
            KeyError: If ``request_id`` is already tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        self.new_requests_event.set()

        return stream

    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
        """Abort a request during next background loop iteration.

        The id is always queued for the engine; the stream (if tracked and
        not yet finished) is finished immediately.
        """
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._finished_requests.put_nowait(request_id)

        if request_id not in self._request_streams or self._request_streams[
                request_id].finished:
            # The request has already finished or been aborted.
            return

        self._request_streams[request_id].finish()

    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        Finished ids are drained first and untracked; queued new requests
        whose id is among them are finished without being forwarded.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._finished_requests.empty():
            request_id = self._finished_requests.get_nowait()
            finished_requests.add(request_id)
            self._request_streams.pop(request_id, None)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            if stream.request_id in finished_requests:
                # The request has already been aborted.
                stream.finish()
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Wait until a new request is queued (returns at once if one
        already is), then clear the event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Whether any added request has not been drained yet."""
        return not self._new_requests.empty()
209

Antoni Baum's avatar
Antoni Baum committed
210
211
212
213

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.

        Args:
            virtual_engine: Index into ``self.scheduler`` selecting which
                scheduler (virtual engine) to step.
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler[
            virtual_engine].schedule()

        finished_requests_ids = self.scheduler[
            virtual_engine].get_and_reset_finished_requests_ids()

        if not scheduler_outputs.is_empty():
            # Execute the model.
            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids)
            output = await self.model_executor.execute_model_async(
                execute_model_req)
        else:
            # Nothing scheduled: still process ignored groups below.
            output = []

        request_outputs = self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

        # Log stats.
        self.do_log_stats(scheduler_outputs, output)

        # Tracing
        self.do_tracing(scheduler_outputs)

        return request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def process_model_inputs_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        """Build the :class:`LLMInputs` for a request.

        A plain string is treated as ``{"prompt": inputs}``. When no
        ``prompt_token_ids`` are supplied the prompt is tokenized
        asynchronously via the tokenizer group (passing ``lora_request``).
        With a prompt adapter, ``prompt_adapter_num_virtual_tokens`` zeros
        are prepended to the token ids. The result is passed through
        ``self.input_processor``.
        """
        if isinstance(inputs, str):
            inputs = {"prompt": inputs}

        if "prompt_token_ids" not in inputs:
            # The argument is presumably the error message used when no
            # tokenizer is available (skip_tokenizer_init=True).
            tokenizer = self.get_tokenizer_group("prompts must be None if "
                                                 "skip_tokenizer_init is True")

            prompt_token_ids = await tokenizer.encode_async(
                request_id=request_id,
                prompt=inputs["prompt"],
                lora_request=lora_request)
        else:
            prompt_token_ids = inputs["prompt_token_ids"]

        if prompt_adapter_request:
            # Reserve one (zero) token id per virtual token of the adapter.
            num_virtual_tokens = (
                prompt_adapter_request.prompt_adapter_num_virtual_tokens)
            prompt_token_ids = [0] * num_virtual_tokens + prompt_token_ids

        llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
                               prompt=inputs.get("prompt"),
                               multi_modal_data=inputs.get("multi_modal_data"))

        return self.input_processor(llm_inputs)

    async def add_request_async(
            self,
            request_id: str,
            inputs: PromptInputs,
            params: Union[SamplingParams, PoolingParams],
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Dict[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> None:
        """Preprocess a request's inputs asynchronously and add it to the
        engine via ``_add_processed_request``.

        ``arrival_time`` defaults to the current ``time.time()``.

        Raises:
            ValueError: If ``lora_request`` is given but LoRA is not enabled.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if arrival_time is None:
            arrival_time = time.time()

        processed_inputs = await self.process_model_inputs_async(
            request_id=request_id,
            inputs=inputs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
        )

    async def check_health_async(self) -> None:
        """Run the tokenizer's (if any) and the model executor's health
        checks; presumably these raise on failure."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
332

333

334
class AsyncLLMEngine:
335
    """An asynchronous wrapper for :class:`LLMEngine`.
336

337
338
339
340
341
    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.
342
343
344
345
346

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
Zhuohan Li's avatar
Zhuohan Li committed
347
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
348
349
            async frontend will be executed in a separate process as the
            model workers.
350
        log_requests: Whether to log the requests.
zspo's avatar
zspo committed
351
352
        max_log_len: Maximum number of prompt characters or prompt ID numbers
            being printed in log.
353
354
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
355
356
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
357
    """
358

Antoni Baum's avatar
Antoni Baum committed
359
360
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

361
362
363
364
365
    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 max_log_len: Optional[int] = None,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Store the wrapper options and create the engine.

        ``*args``/``**kwargs`` are forwarded to :meth:`_init_engine`. The
        background loop is not started here; see
        :meth:`start_background_loop`.
        """
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
        self.max_log_len = max_log_len
        self.engine = self._init_engine(*args, **kwargs)

        self.background_loop: Optional[asyncio.Future] = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        # Failure recorded by set_errored(); None while healthy.
        self._errored_with: Optional[BaseException] = None

        # Lazy initialized fields
        # (assigned in start_background_loop so it uses the right event loop)
        self._request_tracker: RequestTracker

386
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        The async executor class is chosen from the device type (neuron,
        tpu, cpu, openvino, xpu) and, for XPU and all remaining devices
        (GPU executors), from the configured distributed executor backend.
        Selecting the Ray backend initializes a Ray cluster first.

        Raises:
            RuntimeError: For an XPU device whose distributed backend is
                neither ``None`` nor ``"ray"``.
        """
        # Create the engine configs.
        engine_config = engine_args.create_engine_config()

        if engine_args.engine_use_ray:
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()

        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
        device_type = engine_config.device_config.device_type

        # NOTE(review): the CPU/OpenVINO checks below use ``assert``, which
        # is stripped under ``python -O``.
        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            executor_class = NeuronExecutorAsync
        elif device_type == "tpu":
            from vllm.executor.tpu_executor import TPUExecutorAsync
            executor_class = TPUExecutorAsync
        elif device_type == "cpu":
            assert distributed_executor_backend is None, (
                "Distributed execution is not supported with the CPU backend.")
            from vllm.executor.cpu_executor import CPUExecutorAsync
            executor_class = CPUExecutorAsync
        elif device_type == "openvino":
            assert distributed_executor_backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            executor_class = OpenVINOExecutorAsync
        elif device_type == "xpu":
            if distributed_executor_backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                executor_class = XPUExecutorAsync
            elif distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                executor_class = RayXPUExecutorAsync
            else:
                raise RuntimeError(
                    "Not supported distributed execution model on XPU device.")
        elif distributed_executor_backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            executor_class = RayGPUExecutorAsync
        elif distributed_executor_backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            executor_class = MultiprocessingGPUExecutorAsync
        else:
            from vllm.executor.gpu_executor import GPUExecutorAsync
            executor_class = GPUExecutorAsync
        # Create the async LLM engine.
        engine = cls(
            # worker_use_ray: True exactly when the Ray backend is selected.
            distributed_executor_backend == "ray",
            engine_args.engine_use_ray,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            max_log_len=engine_args.max_log_len,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
        )
        return engine

457
458
    @property
    def is_running(self) -> bool:
        """True while the background loop task exists and has not finished."""
        return (self.background_loop is not None
                and self._background_loop_unshielded is not None
                and not self._background_loop_unshielded.done())

    @property
    def is_stopped(self) -> bool:
        """True if the engine has errored, or its background loop task was
        started and has since finished."""
        return self.errored or (self.background_loop is not None and
                                self._background_loop_unshielded is not None
                                and self._background_loop_unshielded.done())

    @property
    def errored(self) -> bool:
        """Whether :meth:`set_errored` has recorded a failure."""
        failure = self._errored_with
        return failure is not None

    def set_errored(self, exc: Exception) -> None:
        """Remember ``exc`` as the failure that took the engine down.

        Afterwards :attr:`errored` is True and :meth:`start_background_loop`
        refuses to start.
        """
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine errored, then forward ``exc`` to every request
        stream the tracker still holds."""
        self.set_errored(exc)
        tracker = self._request_tracker
        tracker.propagate_exception(exc)
479

480
481
    async def get_tokenizer(self) -> "PreTrainedTokenizer":
        """Return the engine's tokenizer, fetched from the Ray actor when the
        engine runs remotely."""
        if self.engine_use_ray:
            return await self.engine.get_tokenizer.remote()  # type: ignore
        else:
            return self.engine.get_tokenizer()
485

486
    def start_background_loop(self) -> None:
        """Start the background loop.

        Creates a fresh :class:`RequestTracker` and schedules
        :meth:`run_engine_loop` as a task on the current event loop; a done
        callback reports the task's termination via ``_error_callback``.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the background loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        self._background_loop_unshielded = asyncio.get_event_loop(
        ).create_task(self.run_engine_loop())
        self._background_loop_unshielded.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        # Shielded so that cancelling an awaiter of ``background_loop`` does
        # not cancel the engine task itself.
        self.background_loop = asyncio.shield(self._background_loop_unshielded)
Antoni Baum's avatar
Antoni Baum committed
501
502
503

    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
        """Instantiate the engine in-process or as a Ray actor.

        * ``engine_use_ray`` False: ``self._engine_class(*args, **kwargs)``.
        * ``engine_use_ray`` and ``worker_use_ray``: a Ray actor with
          ``num_cpus=0``.
        * ``engine_use_ray`` only: a Ray actor requesting
          ``cache_config.gpu_memory_utilization`` GPUs when tensor and
          pipeline parallel sizes are both 1, otherwise one GPU.
        """
        if not self.engine_use_ray:
            engine_class = self._engine_class
        elif self.worker_use_ray:
            engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
        else:
            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
            # order of the arguments.
            cache_config = kwargs["cache_config"]
            parallel_config = kwargs["parallel_config"]
            if (parallel_config.tensor_parallel_size == 1
                    and parallel_config.pipeline_parallel_size == 1):
                num_gpus = cache_config.gpu_memory_utilization
            else:
                num_gpus = 1
            engine_class = ray.remote(num_gpus=num_gpus)(
                self._engine_class).remote
        return engine_class(*args, **kwargs)

522
    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Drains the request tracker, adds new requests to the engine (a
        ``ValueError`` raised while adding is routed to that request's
        stream), aborts finished/aborted requests, runs one engine step and
        dispatches the outputs to their streams. With a Ray engine actor,
        ``virtual_engine`` is not passed to ``step``.

        Returns:
            True if at least one output of this step is unfinished; False if
            all outputs are finished or there were none.
        """
        new_requests, finished_requests = (
            self._request_tracker.get_new_and_finished_requests())

        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            # TODO: Maybe add add_request_batch to reduce Ray overhead
            try:
                if self.engine_use_ray:
                    await self.engine.add_request.remote(  # type: ignore
                        **new_request)
                else:
                    await self.engine.add_request_async(**new_request)
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    new_request["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if finished_requests:
            await self._engine_abort(finished_requests)

        if self.engine_use_ray:
            request_outputs = await self.engine.step.remote()  # type: ignore
        else:
            request_outputs = await self.engine.step_async(virtual_engine)

        # Put the outputs into the corresponding streams.
        finished = True
        for request_output in request_outputs:
            self._request_tracker.process_request_output(
                request_output, verbose=self.log_requests)
            finished = finished and request_output.finished

        return not finished
563

Antoni Baum's avatar
Antoni Baum committed
564
565
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Abort ``request_ids`` in the engine (through the Ray actor when
        the engine runs remotely)."""
        if self.engine_use_ray:
            await self.engine.abort_request.remote(request_ids)  # type: ignore
        else:
            self.engine.abort_request(request_ids)

    async def run_engine_loop(self):
        """Background loop driving the engine; it never returns normally.

        One ``engine_step`` task is kept per virtual engine
        (``pipeline_parallel_size`` of them; exactly one when the engine is a
        Ray actor). When no virtual engine has work, the remote worker
        execution loop is stopped and the loop waits for a new request.

        Raises:
            asyncio.TimeoutError: If one wait on the step tasks exceeds
                ``ENGINE_ITERATION_TIMEOUT_S``; the engine is marked errored
                first.
        """
        if self.engine_use_ray:
            pipeline_parallel_size = 1  # type: ignore
        else:
            pipeline_parallel_size = \
                self.engine.parallel_config.pipeline_parallel_size
        has_requests_in_progress = [False] * pipeline_parallel_size
        while True:
            if not any(has_requests_in_progress):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                if self.engine_use_ray:
                    await (self.engine.stop_remote_worker_execution_loop.
                           remote()  # type: ignore
                           )
                else:
                    await self.engine.stop_remote_worker_execution_loop_async()
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")
                requests_in_progress = [
                    asyncio.create_task(self.engine_step(ve))
                    for ve in range(pipeline_parallel_size)
                ]
                has_requests_in_progress = [True] * pipeline_parallel_size

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        requests_in_progress,
                        return_when=asyncio.FIRST_COMPLETED)
                    for _ in range(pipeline_parallel_size):
                        await asyncio.sleep(0)
                # NOTE(review): tasks of idle virtual engines stay in
                # ``requests_in_progress`` in the done state, so while other
                # engines are busy ``asyncio.wait`` returns immediately and
                # those idle engines are re-checked on every pass.
                for task in done:
                    # Re-raises any exception from engine_step, which ends
                    # this loop.
                    result = task.result()
                    virtual_engine = requests_in_progress.index(task)
                    if self.engine_use_ray:
                        has_unfinished_requests = (
                            await (self.engine.
                                   has_unfinished_requests_for_virtual_engine.
                                   remote(  # type: ignore
                                       virtual_engine)))
                    else:
                        has_unfinished_requests = (
                            self.engine.
                            has_unfinished_requests_for_virtual_engine(
                                virtual_engine))
                    if result or has_unfinished_requests:
                        requests_in_progress[virtual_engine] = (
                            asyncio.create_task(
                                self.engine_step(virtual_engine)))
                        has_requests_in_progress[virtual_engine] = True
                    else:
                        has_requests_in_progress[virtual_engine] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise
            await asyncio.sleep(0)

    async def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncStream:
        """Queue a request and return the stream its outputs arrive on.

        The request is handed to the engine by the background loop on its
        next iteration. If the loop is not running it is started when
        ``start_engine_loop`` is set. ``arrival_time`` defaults to the
        current ``time.time()``. When ``log_requests`` is set, the prompt
        and token ids are logged, truncated to ``max_log_len`` if given.

        Raises:
            AsyncEngineDeadError: If the loop is not running and
                ``start_engine_loop`` is False, or the engine has errored.
            KeyError: If ``request_id`` is already tracked.
        """
        if self.log_requests:
            if isinstance(inputs, str):
                shortened_prompt = inputs
                shortened_token_ids = None
            else:
                shortened_prompt = inputs.get("prompt")
                shortened_token_ids = inputs.get("prompt_token_ids")

            max_log_len = self.max_log_len
            if max_log_len is not None:
                if shortened_prompt is not None:
                    shortened_prompt = shortened_prompt[:max_log_len]
                if shortened_token_ids is not None:
                    shortened_token_ids = shortened_token_ids[:max_log_len]

            logger.info(
                "Received request %s: prompt: %r, "
                "params: %s, prompt_token_ids: %s, "
                "lora_request: %s.", request_id, shortened_prompt, params,
                shortened_token_ids, lora_request)

        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            inputs=inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request)

        return stream
691

692
    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncIterator[RequestOutput]:
        """Generate outputs for a request.

        This method is an async generator. It submits the request to the
        engine and streams the resulting outputs back to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the background loop is not running, it is started when
              the engine was created with `start_engine_loop`; otherwise
              `AsyncEngineDeadError` is raised. The loop repeatedly invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - The request is added to the engine's `RequestTracker`, which
              creates a corresponding `AsyncStream`. On the next background
              loop iteration the request is sent to the underlying engine.
            - Outputs are read from the `AsyncStream` and yielded. If reading
              raises an exception or the task is cancelled, the request is
              aborted before the error propagates.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming client (HTTP) request object
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        results = self._process_request(
            request_id,
            inputs,
            sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
        )
        # Narrow the shared (RequestOutput | EmbeddingRequestOutput) stream to
        # the type this entry point promises.
        async for output in results:
            yield LLMEngine.validate_output(output, RequestOutput)

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> AsyncIterator[EmbeddingRequestOutput]:
        """Generate outputs for a request from an embedding model.

        This method is an async generator. It submits the request to the
        engine and streams the resulting outputs back to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.

        Yields:
            The output `EmbeddingRequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the background loop is not running, it is started when
              the engine was created with `start_engine_loop`; otherwise
              `AsyncEngineDeadError` is raised. The loop repeatedly invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - The request is added to the engine's `RequestTracker`, which
              creates a corresponding `AsyncStream`. On the next background
              loop iteration the request is sent to the underlying engine.
            - Outputs are read from the `AsyncStream` and yielded. If reading
              raises an exception or the task is cancelled, the request is
              aborted before the error propagates.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming client (HTTP) request object
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        results = self._process_request(
            request_id,
            inputs,
            pooling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )
        # Narrow the shared stream to embedding outputs only.
        async for output in results:
            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)

852
    async def _process_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        *,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Common logic to process requests with SamplingParams or
        PoolingParams.

        Registers the request via :meth:`add_request`, using the current
        wall-clock time as its arrival time, and re-yields every output from
        the returned stream.

        If consuming the stream stops abnormally, the request is aborted in
        the engine and the original exception is re-raised. Abnormal stops
        are an exception, task cancellation, or the caller closing this
        generator before the stream is exhausted.
        """
        arrival_time = time.time()

        stream = await self.add_request(
            request_id,
            inputs,
            params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
        )

        try:
            async for request_output in stream:
                yield request_output
        except (Exception, asyncio.CancelledError, GeneratorExit):
            # GeneratorExit is thrown in at the ``yield`` above when the
            # consumer closes this async generator early (``aclose()`` or
            # finalization after it is dropped). It derives from
            # BaseException, so ``Exception`` alone misses it and the request
            # would keep running in the engine with nobody reading it.
            # _abort is a plain (non-async) call, so nothing here yields or
            # awaits while the generator is being closed.
            self._abort(request_id)
            # Bare ``raise`` re-raises the original exception unchanged.
            raise

Antoni Baum's avatar
Antoni Baum committed
883
884
    async def abort(self, request_id: str) -> None:
        """Abort a submitted request.

        Aborting a request that has already finished, or whose id is not
        known, is a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if self.is_running:
            return self._abort(request_id)

        raise AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")

Antoni Baum's avatar
Antoni Baum committed
901
    def _abort(self, request_id: str) -> None:
902
903
904
905
906
907
908
909
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
910
911
        self._request_tracker.abort_request(request_id,
                                            verbose=self.log_requests)
912

913
914
915
    async def get_model_config(self) -> ModelConfig:
        """Return the model configuration of the vLLM engine.

        With ``engine_use_ray`` set, the config is fetched through the
        engine's ``.remote()`` call and awaited; otherwise the local engine's
        method is called directly.
        """
        if not self.engine_use_ray:
            return self.engine.get_model_config()
        return await self.engine.get_model_config.remote()  # type: ignore

920
921
922
923
924
925
926
927
    async def get_decoding_config(self) -> DecodingConfig:
        """Return the decoding configuration of the vLLM engine.

        With ``engine_use_ray`` set, the config is fetched through the
        engine's ``.remote()`` call and awaited; otherwise the local engine's
        method is called directly.
        """
        if not self.engine_use_ray:
            return self.engine.get_decoding_config()
        return await self.engine.get_decoding_config.remote(  # type: ignore
        )

928
929
930
931
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the underlying engine to log its stats.

        Both arguments are forwarded unchanged to the engine's
        ``do_log_stats``. This happens both through the ``.remote()`` call
        when ``engine_use_ray`` is set and through a direct call otherwise.
        With no arguments, ``None`` is forwarded for both.

        Args:
            scheduler_outputs: Scheduler outputs to base the stats on, if any.
            model_output: Model outputs to base the stats on, if any.
        """
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote(  # type: ignore
                scheduler_outputs, model_output)
        else:
            # Previously the in-process call dropped both arguments, so the
            # same call behaved differently depending on engine_use_ray.
            # NOTE(review): assumes the local engine's do_log_stats takes the
            # same two positional arguments the Ray branch already passes.
            self.engine.do_log_stats(scheduler_outputs, model_output)

938
    async def check_health(self) -> None:
        """Raise an error if the engine is unhealthy.

        The elapsed time of the check is logged at debug level.

        Raises:
            AsyncEngineDeadError: If the background loop is stopped.
            RuntimeError: If ``engine_use_ray`` is set and the remote health
                check fails with ``ray.exceptions.RayActorError``.
            Any error raised by the local engine's ``check_health_async``
            propagates unchanged.
        """
        started = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        if not self.engine_use_ray:
            await self.engine.check_health_async()
        else:
            try:
                await self.engine.check_health.remote()  # type: ignore
            except ray.exceptions.RayActorError as err:
                raise RuntimeError("Engine is dead.") from err
        logger.debug("Health check took %fs", time.perf_counter() - started)

    async def is_tracing_enabled(self) -> bool:
        """Report whether tracing is enabled on the underlying engine.

        With ``engine_use_ray`` set, the answer comes from the engine's
        ``.remote()`` call; otherwise the local engine is asked directly.
        """
        if not self.engine_use_ray:
            return self.engine.is_tracing_enabled()
        return await self.engine.is_tracing_enabled.remote(  # type: ignore
        )