async_llm_engine.py 34 KB
Newer Older
1
2
import asyncio
import time
Antoni Baum's avatar
Antoni Baum committed
3
from functools import partial
4
5
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
                    Set, Tuple, Type, Union)
6

7
8
from transformers import PreTrainedTokenizer

9
import vllm.envs as envs
10
from vllm.config import DecodingConfig, ModelConfig
11
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
12
13
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
14
from vllm.executor.ray_utils import initialize_ray_cluster, ray
15
from vllm.inputs import LLMInputs, PromptInputs
Woosuk Kwon's avatar
Woosuk Kwon committed
16
from vllm.logger import init_logger
17
from vllm.lora.request import LoRARequest
18
19
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
20
from vllm.sampling_params import SamplingParams
21
from vllm.sequence import ExecuteModelRequest, SamplerOutput
yhu422's avatar
yhu422 committed
22
from vllm.usage.usage_lib import UsageContext
23
24

logger = init_logger(__name__)

# Upper bound, in seconds, on a single engine iteration. `run_engine_loop`
# wraps every `engine_step` in `asyncio.wait_for` with this timeout and marks
# the engine errored if it is exceeded.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S


class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine's background loop has died or is not running."""


def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback for the `engine.run_engine_loop()` background task.

    That task runs a `while True` loop, so it can only end by being cancelled
    or by raising. Cancellation is logged as a graceful shutdown. A normal
    return is treated as a bug: an `AssertionError` is raised and handled
    like any other failure.

    Args:
        task: The finished background task.
        error_callback: Called with the exception that ended the task, before
            `AsyncEngineDeadError` is raised.

    Raises:
        AsyncEngineDeadError: If the task ended with an exception, or returned
            normally. The original exception is chained as `__cause__`.
    """
    try:
        return_value = task.result()
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {return_value}")
    except asyncio.CancelledError:
        # We assume that if the task is cancelled, we are gracefully shutting
        # down. This should only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        # The trailing space after "for the" is required: the two literals
        # are concatenated and previously read "for theactual cause.".
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from e


class AsyncStream:
    """Async-iterable channel that carries the outputs of a single request.

    Producers push `RequestOutput` / `EmbeddingRequestOutput` objects (or an
    exception) with `put` and close the stream with `finish`. Consumers read
    it with `async for`. A queued exception is raised to the consumer instead
    of being returned.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._finished = False
        self._queue: asyncio.Queue = asyncio.Queue()

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Enqueue `item`. Items put after `finish` are silently dropped."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(self) -> None:
        """Enqueue the end-of-stream sentinel and mark the stream finished."""
        self._queue.put_nowait(StopAsyncIteration())
        self._finished = True

    @property
    def finished(self) -> bool:
        """Whether `finish` has been called on this stream."""
        return self._finished

    def __aiter__(self):
        return self

    async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
        """Wait for the next queued item.

        Exception items, including the `StopAsyncIteration` sentinel enqueued
        by `finish`, are raised rather than returned.
        """
        item = await self._queue.get()
        if not isinstance(item, Exception):
            return item
        raise item


class RequestTracker:
    """Synchronous abstraction for tracking requests.

    Callers register requests with `add_request` and cancel them with
    `abort_request`. The engine's background loop drains both via
    `get_new_and_finished_requests` and routes engine results back through
    `process_request_output` / `process_exception`.
    """

    def __init__(self) -> None:
        # Streams of requests that have been handed to the engine.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Ids of finished or aborted requests, drained on each loop iteration.
        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
        # Requests added via `add_request` that the loop has not picked up yet.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Ids of the entries currently waiting in `_new_requests`. They are
        # not in `_request_streams` until the loop drains the queue, so
        # without this set a second `add_request` with the same id would pass
        # the duplicate check and later overwrite the first request's stream.
        self._pending_request_ids: Set[str] = set()
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None).

        Each affected stream receives `exc` and the request is then aborted.
        A `request_id` that has no tracked stream is still aborted, but no
        `KeyError` is raised.
        """
        if request_id is not None:
            stream = self._request_streams.get(request_id)
            if stream is not None:
                stream.put(exc)
            self.abort_request(request_id)
        else:
            # abort_request does not mutate _request_streams, so iterating
            # the dict directly is safe here.
            for rid, stream in self._request_streams.items():
                stream.put(exc)
                self.abort_request(rid)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        The output is forwarded to the request's stream if it is still
        tracked. Once the output is marked finished, the request is aborted,
        which closes its stream and drops it on the next loop iteration.
        """
        request_id = request_output.request_id

        # Use .get so an untracked request cannot raise KeyError.
        stream = self._request_streams.get(request_id)
        if stream is not None:
            stream.put(request_output)
        if request_output.finished:
            if verbose:
                logger.info("Finished request %s.", request_id)
            self.abort_request(request_id)

    def process_exception(self,
                          request_id: str,
                          exception: Exception,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine.

        The exception is forwarded to the request's stream if it is still
        tracked, and the request is then aborted.
        """
        stream = self._request_streams.get(request_id)
        if stream is not None:
            stream.put(exception)
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id)

    def add_request(self, request_id: str,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Raises:
            KeyError: If a request with `request_id` is already tracked or is
                still waiting to be picked up by the loop.
        """
        if (request_id in self._request_streams
                or request_id in self._pending_request_ids):
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))
        self._pending_request_ids.add(request_id)

        self.new_requests_event.set()

        return stream

    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
        """Abort a request during next background loop iteration."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._finished_requests.put_nowait(request_id)

        if request_id not in self._request_streams or self._request_streams[
                request_id].finished:
            # The request has already finished or been aborted.
            return

        self._request_streams[request_id].finish()

    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        Finished ids are drained first so that a request aborted before the
        loop picked it up is finished here and never sent to the engine.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._finished_requests.empty():
            request_id = self._finished_requests.get_nowait()
            finished_requests.add(request_id)
            self._request_streams.pop(request_id, None)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            # The request leaves the pending state whether it is accepted or
            # dropped below.
            self._pending_request_ids.discard(stream.request_id)
            if stream.request_id in finished_requests:
                # The request has already been aborted.
                stream.finish()
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Block until at least one new request has been queued."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        return not self._new_requests.empty()


class _AsyncLLMEngine(LLMEngine):
    """`LLMEngine` subclass that adds coroutine-based entry points."""

    async def step_async(
            self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Run one decoding iteration and return the newly generated results.

        The iteration proceeds in these steps:
        1. Ask the scheduler which sequence groups run next and which token
           blocks to swap in, swap out or copy.
        2. If anything was scheduled, execute the model through the
           executor's async API.
        3. Turn the model output into request outputs.
        4. Log stats and emit tracing.
        5. If no request produced output, stop the remote workers' execution
           loop.
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

        if scheduler_outputs.is_empty():
            # Nothing was scheduled, so there is no model output to process.
            output = []
        else:
            output = await self.model_executor.execute_model_async(
                ExecuteModelRequest(
                    seq_group_metadata_list=seq_group_metadata_list,
                    blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                    blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                    blocks_to_copy=scheduler_outputs.blocks_to_copy,
                    num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                    running_queue_size=scheduler_outputs.running_queue_size,
                ))

        request_outputs = self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

        # Stats and tracing are emitted on every iteration.
        self.do_log_stats(scheduler_outputs, output)
        self.do_tracing(scheduler_outputs)

        if not request_outputs:
            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            await self.model_executor.stop_remote_worker_execution_loop_async()

        return request_outputs

    async def process_model_inputs_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        lora_request: Optional[LoRARequest] = None,
    ) -> LLMInputs:
        """Normalize `inputs` into `LLMInputs`.

        A plain string is treated as `{"prompt": inputs}`. Token ids are taken
        from "prompt_token_ids" when present; otherwise the prompt is encoded
        with the tokenizer group.
        """
        prompt_dict = {"prompt": inputs} if isinstance(inputs, str) else inputs

        if "prompt_token_ids" in prompt_dict:
            token_ids = prompt_dict["prompt_token_ids"]
        else:
            tokenizer = self.get_tokenizer_group("prompts must be None if "
                                                 "skip_tokenizer_init is True")
            token_ids = await tokenizer.encode_async(
                request_id=request_id,
                prompt=prompt_dict["prompt"],
                lora_request=lora_request)

        return LLMInputs(prompt_token_ids=token_ids,
                         prompt=prompt_dict.get("prompt"),
                         multi_modal_data=prompt_dict.get("multi_modal_data"))

    async def add_request_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> None:
        """Validate, preprocess and enqueue a request on the engine.

        Raises:
            ValueError: If a LoRA request is given but LoRA is not enabled.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        # Capture the arrival time before awaiting tokenization.
        arrival_time = time.time() if arrival_time is None else arrival_time

        processed_inputs = await self.process_model_inputs_async(
            request_id=request_id, inputs=inputs, lora_request=lora_request)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )

    async def check_health_async(self) -> None:
        """Delegate to the model executor's (synchronous) health check."""
        self.model_executor.check_health()


class AsyncLLMEngine:
    """An asynchronous wrapper for :class:`LLMEngine`.

    This class wraps the :class:`LLMEngine` class to make it asynchronous.
    It uses asyncio to run a background loop that keeps processing incoming
    requests. Requests submitted through the generate method wake that loop,
    which steps the :class:`LLMEngine` while there are requests in the waiting
    queue. The generate method yields the outputs from the
    :class:`LLMEngine` to the caller.

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
            async frontend will be executed in a separate process from the
            model workers.
        log_requests: Whether to log the requests.
        max_log_len: Maximum number of prompt characters or prompt ID numbers
            being printed in log.
        start_engine_loop: If True, the background task to run the engine
            will be started automatically when the first request is added
            (e.g. by the generate call).
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
    """

    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

342
343
344
345
346
    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 max_log_len: Optional[int] = None,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        # `_init_engine` reads `engine_use_ray` and `worker_use_ray`, so both
        # must be assigned before the engine is constructed.
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
        self.max_log_len = max_log_len
        self.engine = self._init_engine(*args, **kwargs)

        # Background-loop state, populated by `start_background_loop`.
        # `background_loop` is the shielded future. The unshielded task is
        # also kept so that it is not garbage collected.
        self.background_loop: Optional[asyncio.Future] = None
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Created lazily in `start_background_loop`, on the running loop.
        self._request_tracker: RequestTracker

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        The executor class is picked by device type first (neuron, tpu, cpu,
        xpu). For any other device it is picked by the distributed executor
        backend: "ray", "mp", or none (single-process GPU).
        """
        engine_config = engine_args.create_engine_config()
        parallel_config = engine_config.parallel_config
        backend = parallel_config.distributed_executor_backend
        device_type = engine_config.device_config.device_type

        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            executor_class = NeuronExecutorAsync
        elif device_type == "tpu":
            from vllm.executor.tpu_executor import TPUExecutorAsync
            executor_class = TPUExecutorAsync
        elif device_type == "cpu":
            assert backend is None, (
                "Distributed execution is not supported with the CPU backend.")
            from vllm.executor.cpu_executor import CPUExecutorAsync
            executor_class = CPUExecutorAsync
        elif device_type == "xpu":
            if backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                executor_class = XPUExecutorAsync
            elif backend == "ray":
                initialize_ray_cluster(parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                executor_class = RayXPUExecutorAsync
            else:
                raise RuntimeError(
                    "Not supported distributed execution model on XPU device.")
        elif backend == "ray":
            initialize_ray_cluster(parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            executor_class = RayGPUExecutorAsync
        elif backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            executor_class = MultiprocessingGPUExecutorAsync
        else:
            from vllm.executor.gpu_executor import GPUExecutorAsync
            executor_class = GPUExecutorAsync

        # Positional args: worker_use_ray, engine_use_ray.
        return cls(
            backend == "ray",
            engine_args.engine_use_ray,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            max_log_len=engine_args.max_log_len,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
        )

    @property
    def is_running(self) -> bool:
        """True while the background loop task exists and has not finished."""
        if self.background_loop is None:
            return False
        task = self._background_loop_unshielded
        return task is not None and not task.done()

    @property
    def is_stopped(self) -> bool:
        """True once the engine has errored or its background task is done."""
        if self.errored:
            return True
        if self.background_loop is None:
            return False
        task = self._background_loop_unshielded
        return task is not None and task.done()

    @property
    def errored(self) -> bool:
        """Whether an exception has been recorded via `set_errored`."""
        return not (self._errored_with is None)

    def set_errored(self, exc: Exception) -> None:
        """Record `exc` as the reason the engine can no longer be used."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine errored, then push `exc` to every tracked stream."""
        self.set_errored(exc)
        tracker = self._request_tracker
        tracker.propagate_exception(exc)

    async def get_tokenizer(self) -> "PreTrainedTokenizer":
        """Return the engine's tokenizer, via a Ray call if the engine is an
        actor."""
        if not self.engine_use_ray:
            return self.engine.get_tokenizer()
        return await self.engine.get_tokenizer.remote()  # type: ignore

    def start_background_loop(self) -> None:
        """Start the background loop.

        Creates a fresh `RequestTracker`, schedules `run_engine_loop` as a
        task with `_log_task_completion` as its done-callback, and keeps both
        the raw task and a shielded wrapper of it.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the background loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        event_loop = asyncio.get_event_loop()
        task = event_loop.create_task(self.run_engine_loop())
        self._background_loop_unshielded = task
        task.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        self.background_loop = asyncio.shield(task)

    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
        """Construct the engine in-process or as a Ray remote actor.

        Without `engine_use_ray` the engine is built locally. With it, the
        actor reserves no CPUs when workers also use Ray. Otherwise it
        reserves GPU capacity: a fraction equal to `gpu_memory_utilization`
        when tensor_parallel_size == 1, else one full GPU.
        """
        if not self.engine_use_ray:
            return self._engine_class(*args, **kwargs)

        if self.worker_use_ray:
            actor_options = {"num_cpus": 0}
        else:
            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
            # order of the arguments.
            cache_config = kwargs["cache_config"]
            parallel_config = kwargs["parallel_config"]
            if parallel_config.tensor_parallel_size == 1:
                num_gpus = cache_config.gpu_memory_utilization
            else:
                num_gpus = 1
            actor_options = {"num_gpus": num_gpus}

        remote_factory = ray.remote(**actor_options)(self._engine_class).remote
        return remote_factory(*args, **kwargs)

    async def engine_step(self) -> bool:
        """Kick the engine to process the waiting requests.

        The step proceeds in this order:
        1. Forward newly added requests to the engine. A `ValueError` while
           adding one is routed to that request's stream.
        2. Abort finished or aborted requests in the engine.
        3. Run one engine step.
        4. Dispatch the outputs to their streams.

        Returns True if there are in-progress requests."""
        new_requests, finished_requests = (
            self._request_tracker.get_new_and_finished_requests())

        # TODO: Maybe add add_request_batch to reduce Ray overhead
        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            try:
                if not self.engine_use_ray:
                    await self.engine.add_request_async(**new_request)
                else:
                    await self.engine.add_request.remote(  # type: ignore
                        **new_request)
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    new_request["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if finished_requests:
            await self._engine_abort(finished_requests)

        if not self.engine_use_ray:
            request_outputs = await self.engine.step_async()
        else:
            request_outputs = await self.engine.step.remote()  # type: ignore

        # Route every output to the stream of the request it belongs to.
        for request_output in request_outputs:
            self._request_tracker.process_request_output(
                request_output, verbose=self.log_requests)

        return len(request_outputs) > 0

    async def _engine_abort(self, request_ids: Iterable[str]):
        """Abort `request_ids` in the engine (local call or Ray actor call)."""
        if not self.engine_use_ray:
            self.engine.abort_request(request_ids)
            return
        await self.engine.abort_request.remote(request_ids)  # type: ignore

    async def run_engine_loop(self):
        """Drive `engine_step` forever, idling while there is nothing to do.

        Raises:
            asyncio.TimeoutError: If one iteration exceeds
                `ENGINE_ITERATION_TIMEOUT_S`. The engine is marked errored
                before the error propagates.
        """
        busy = False
        while True:
            if not busy:
                logger.debug("Waiting for new requests...")
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                busy = await asyncio.wait_for(self.engine_step(),
                                              ENGINE_ITERATION_TIMEOUT_S)
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise

            # Yield to the event loop so other coroutines can run between
            # iterations.
            await asyncio.sleep(0)

    async def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> AsyncStream:
        """Register a request and return the stream its outputs arrive on.

        Optionally logs the request, truncating the prompt and token ids to
        `max_log_len`. Starts the background loop if it is not running and
        `start_engine_loop` is set.

        Raises:
            AsyncEngineDeadError: If the loop is not running and may not be
                started automatically.
        """
        if self.log_requests:
            if isinstance(inputs, str):
                prompt_to_log, token_ids_to_log = inputs, None
            else:
                prompt_to_log = inputs.get("prompt")
                token_ids_to_log = inputs.get("prompt_token_ids")

            limit = self.max_log_len
            if limit is not None:
                if prompt_to_log is not None:
                    prompt_to_log = prompt_to_log[:limit]
                if token_ids_to_log is not None:
                    token_ids_to_log = token_ids_to_log[:limit]

            logger.info(
                "Received request %s: prompt: %r, "
                "params: %s, prompt_token_ids: %s, "
                "lora_request: %s.", request_id, prompt_to_log, params,
                token_ids_to_log, lora_request)

        if not self.is_running:
            if not self.start_engine_loop:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")
            self.start_background_loop()

        return self._request_tracker.add_request(
            request_id,
            inputs=inputs,
            params=params,
            arrival_time=time.time() if arrival_time is None else arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )

    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> AsyncIterator[RequestOutput]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is an async generator. It
        adds the request into the waiting queue of the LLMEngine and streams
        the outputs from the LLMEngine to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming HTTP request object of the
            >>>     # serving endpoint (e.g. in api_server.py).
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in self._process_request(
                request_id,
                inputs,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
        ):
            # Narrow the stream's union type to RequestOutput for callers.
            yield LLMEngine.validate_output(output, RequestOutput)

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> AsyncIterator[EmbeddingRequestOutput]:
        """Stream the outputs of an embedding-model request.

        This coroutine submits the request to the waiting queue of the
        LLMEngine and forwards every output the engine produces for it back
        to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.

        Yields:
            The `EmbeddingRequestOutput` objects produced by the LLMEngine
            for this request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the client's HTTP request object
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        outputs = self._process_request(
            request_id,
            inputs,
            pooling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )
        async for raw_output in outputs:
            # Check the output against the expected embedding output type
            # via the engine's validator before handing it to the caller.
            embedding_output = LLMEngine.validate_output(
                raw_output, EmbeddingRequestOutput)
            yield embedding_output

    async def _process_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        *,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Common logic to process requests with SamplingParams or
        PoolingParams.

        Registers the request via ``add_request`` (stamping the current time
        as its arrival time) and re-yields every output from the returned
        stream.

        If iteration ends abnormally, the request is aborted with ``_abort``
        and the original exception is re-raised. Abnormal endings are:
            - an exception raised by the stream;
            - task cancellation (``asyncio.CancelledError``);
            - the consumer closing this generator early (``GeneratorExit``).
        """
        arrival_time = time.time()

        stream = await self.add_request(
            request_id,
            inputs,
            params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )

        try:
            async for request_output in stream:
                yield request_output
        except (Exception, asyncio.CancelledError, GeneratorExit):
            # GeneratorExit derives from BaseException, so it must be listed
            # explicitly. Without it, an early ``aclose()`` (or finalization)
            # by the consumer would leave the request running in the engine.
            self._abort(request_id)
            # Bare ``raise`` re-raises the in-flight exception with its
            # original traceback intact.
            raise

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if self.is_running:
            return self._abort(request_id)

        raise AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")

    def _abort(self, request_id: str) -> None:
815
816
817
818
819
820
821
822
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
823
824
        self._request_tracker.abort_request(request_id,
                                            verbose=self.log_requests)
825

826
827
828
    async def get_model_config(self) -> ModelConfig:
        """Return the model configuration of the vLLM engine.

        When the engine runs as a Ray actor, the config is fetched with an
        awaited remote call. Otherwise it is read directly from the local
        engine.
        """
        if not self.engine_use_ray:
            return self.engine.get_model_config()
        return await self.engine.get_model_config.remote()  # type: ignore

    async def get_decoding_config(self) -> DecodingConfig:
        """Return the decoding configuration of the vLLM engine.

        When the engine runs as a Ray actor, the config is fetched with an
        awaited remote call. Otherwise it is read directly from the local
        engine.
        """
        if not self.engine_use_ray:
            return self.engine.get_decoding_config()
        return await self.engine.get_decoding_config.remote()  # type: ignore

    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the underlying engine to log its statistics.

        Args:
            scheduler_outputs: Optional scheduler outputs, forwarded unchanged
                to the engine's ``do_log_stats``.
            model_output: Optional sampler outputs, forwarded unchanged to
                the engine's ``do_log_stats``.
        """
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote(  # type: ignore
                scheduler_outputs, model_output)
        else:
            # Forward the same arguments as the Ray path. Otherwise the
            # caller's outputs would be silently dropped whenever the engine
            # runs locally.
            self.engine.do_log_stats(scheduler_outputs, model_output)

    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy.

        Raises:
            AsyncEngineDeadError: If the background loop has been stopped.
            RuntimeError: If the Ray engine actor has died
                (``RayActorError`` from the remote health check).
        """
        started = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        if not self.engine_use_ray:
            await self.engine.check_health_async()
        else:
            try:
                await self.engine.check_health.remote()  # type: ignore
            except ray.exceptions.RayActorError as e:
                # Present a dead actor as a plain RuntimeError, keeping the
                # Ray error as the cause.
                raise RuntimeError("Engine is dead.") from e

        elapsed = time.perf_counter() - started
        logger.debug("Health check took %fs", elapsed)

    async def is_tracing_enabled(self) -> bool:
        """Report whether tracing is enabled on the underlying engine.

        Uses an awaited remote call when the engine runs as a Ray actor and
        a direct call on the local engine otherwise.
        """
        if not self.engine_use_ray:
            return self.engine.is_tracing_enabled()
        return await self.engine.is_tracing_enabled.remote()  # type: ignore