async_llm_engine.py 33.1 KB
Newer Older
1
2
import asyncio
import time
Antoni Baum's avatar
Antoni Baum committed
3
from functools import partial
4
5
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
                    Set, Tuple, Type, Union)
6

7
8
from transformers import PreTrainedTokenizer

9
import vllm.envs as envs
10
from vllm.config import DecodingConfig, ModelConfig
11
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
12
13
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
14
from vllm.executor.ray_utils import initialize_ray_cluster, ray
15
from vllm.inputs import LLMInputs, PromptInputs
Woosuk Kwon's avatar
Woosuk Kwon committed
16
from vllm.logger import init_logger
17
from vllm.lora.request import LoRARequest
18
19
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
20
from vllm.sampling_params import SamplingParams
21
from vllm.sequence import ExecuteModelRequest, SamplerOutput
yhu422's avatar
yhu422 committed
22
from vllm.usage.usage_lib import UsageContext
23
24

# Module-level logger shared by the classes and helpers in this file.
logger = init_logger(__name__)
25
# Upper bound, in seconds, on a single engine iteration. run_engine_loop passes
# it to asyncio.wait_for and marks the engine as errored when it is exceeded.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
26

Antoni Baum's avatar
Antoni Baum committed
27

28
29
30
31
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine's background loop has died or is not running."""


32
33
34
35
36
37
38
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback for the `engine.run_engine_loop()` task.

    That task runs a `while True` loop that can only exit if there is an
    exception, so a normal return is itself treated as a failure.

    Args:
        task: The finished background task.
        error_callback: Called with the exception that ended the task
            (including the AssertionError raised here for a normal return).

    Raises:
        AsyncEngineDeadError: If the task raised or returned normally.
            Cancellation is only logged and is not re-raised.
    """
    try:
        return_value = task.result()
        # Raised inside the `try` on purpose, so that a normal return is
        # handled by the generic `except Exception` branch below.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {return_value}")
    except asyncio.exceptions.CancelledError:
        # We assume that if the task is cancelled, we are gracefully shutting
        # down. This should only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from e
58
59


Antoni Baum's avatar
Antoni Baum committed
60
class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously.

    Producers call :meth:`put` for each output (or an exception to surface to
    the consumer) and :meth:`finish` once nothing more will follow. Consumers
    use ``async for``. Iteration stops after every item enqueued before
    :meth:`finish` has been delivered. Once stopped, later ``__anext__`` calls
    keep raising StopAsyncIteration instead of blocking.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._queue: asyncio.Queue = asyncio.Queue()
        # Set by finish(); after that, put() silently drops new items.
        self._finished = False
        # Set once the consumer has received the end-of-stream sentinel, so
        # later __anext__ calls do not wait on an empty queue forever.
        self._exhausted = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Enqueue an item; no-op once the stream has been finished."""
        if self._finished:
            return
        self._queue.put_nowait(item)

    def finish(self) -> None:
        """Mark the stream complete. Repeated calls are no-ops."""
        if self._finished:
            return
        self._queue.put_nowait(StopAsyncIteration())
        self._finished = True

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has been called."""
        return self._finished

    def __aiter__(self):
        return self

    async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
        """Return the next output, or raise a queued exception.

        Raises:
            StopAsyncIteration: When the end-of-stream sentinel is reached,
                and on every call after that.
            Exception: Any exception that was passed to :meth:`put`.
        """
        if self._exhausted:
            raise StopAsyncIteration
        result = await self._queue.get()
        if isinstance(result, StopAsyncIteration):
            self._exhausted = True
        if isinstance(result, Exception):
            raise result
        return result


93
94
95
96
97
98
99
100
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    Callers queue new requests and aborts here. The background loop collects
    them in one batch per iteration via :meth:`get_new_and_finished_requests`
    and routes engine outputs back to the matching :class:`AsyncStream`.
    """

    def __init__(self) -> None:
        # Streams of requests that have been handed to the engine.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Request ids to stop tracking (and to abort in the engine).
        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, engine add_request kwargs) pairs not yet sent to the engine.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set whenever a new request is queued; awaited while the loop idles.
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None).

        A request id whose stream is no longer tracked is still aborted, but
        no longer raises KeyError.
        """
        if request_id is not None:
            stream = self._request_streams.get(request_id)
            if stream is not None:
                stream.put(exc)
            self.abort_request(request_id)
        else:
            for rid, stream in self._request_streams.items():
                stream.put(exc)
                self.abort_request(rid)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        Outputs for a request whose stream is no longer tracked are dropped.
        Raising KeyError here would propagate out of the engine step and end
        the background loop for every other request.
        """
        request_id = request_output.request_id
        stream = self._request_streams.get(request_id)
        if stream is None:
            # NOTE(review): expected only when the request was aborted and
            # removed while this output was being generated.
            return

        stream.put(request_output)
        if request_output.finished:
            if verbose:
                logger.info("Finished request %s.", request_id)
            self.abort_request(request_id)

    def process_exception(self,
                          request_id: str,
                          exception: Exception,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine.

        The exception is delivered to the request's stream if it is still
        tracked, and the request is then aborted either way.
        """
        stream = self._request_streams.get(request_id)
        if stream is not None:
            stream.put(exception)
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id)

    def add_request(self, request_id: str,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Raises:
            KeyError: If a stream for ``request_id`` is already tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        self.new_requests_event.set()

        return stream

    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
        """Abort a request during next background loop iteration."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._finished_requests.put_nowait(request_id)

        if request_id not in self._request_streams or self._request_streams[
                request_id].finished:
            # The request has already finished or been aborted.
            return

        self._request_streams[request_id].finish()

    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        Finished ids are drained first, so a request that was aborted before
        it ever reached the engine is finished here and never forwarded.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._finished_requests.empty():
            request_id = self._finished_requests.get_nowait()
            finished_requests.add(request_id)
            self._request_streams.pop(request_id, None)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            if stream.request_id in finished_requests:
                # The request has already been aborted.
                stream.finish()
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Block until at least one new request is queued."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Whether any new request is waiting to be sent to the engine."""
        return not self._new_requests.empty()
207

Antoni Baum's avatar
Antoni Baum committed
208
209
210
211

class _AsyncLLMEngine(LLMEngine):
    """LLMEngine subclass that adds coroutine versions of the step and
    request-admission entry points used by :class:`AsyncLLMEngine`."""

    async def step_async(
            self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

        The scheduler first picks the sequence groups to run and the cache
        blocks to swap in, swap out or copy. If anything was scheduled, the
        model is executed through the executor's async API. The raw output is
        then turned into request outputs and stats are logged. If no request
        outputs were produced, the remote workers' execution loop is stopped.
        """
        batch_metadata, sched = self.scheduler.schedule()

        if sched.is_empty():
            output = []
        else:
            # Execute the model on the scheduled batch.
            model_request = ExecuteModelRequest(
                seq_group_metadata_list=batch_metadata,
                blocks_to_swap_in=sched.blocks_to_swap_in,
                blocks_to_swap_out=sched.blocks_to_swap_out,
                blocks_to_copy=sched.blocks_to_copy,
                num_lookahead_slots=sched.num_lookahead_slots,
                running_queue_size=sched.running_queue_size,
            )
            output = await self.model_executor.execute_model_async(
                model_request)

        request_outputs = self._process_model_outputs(
            output,
            sched.scheduled_seq_groups,
            sched.ignored_seq_groups,
            batch_metadata,
        )

        self.do_log_stats(sched, output)

        if not request_outputs:
            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            await self.model_executor.stop_remote_worker_execution_loop_async()

        return request_outputs

    async def process_model_inputs_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        lora_request: Optional[LoRARequest] = None,
    ) -> LLMInputs:
        """Normalize ``inputs`` into :class:`LLMInputs`, tokenizing if needed.

        A bare string is treated as ``{"prompt": inputs}``. When no
        ``prompt_token_ids`` are supplied, the prompt is encoded with the
        engine's tokenizer group.
        """
        prompt_dict = {"prompt": inputs} if isinstance(inputs, str) else inputs

        if "prompt_token_ids" in prompt_dict:
            token_ids = prompt_dict["prompt_token_ids"]
        else:
            tokenizer = self.get_tokenizer_group("prompts must be None if "
                                                 "skip_tokenizer_init is True")
            token_ids = await tokenizer.encode_async(
                request_id=request_id,
                prompt=prompt_dict["prompt"],
                lora_request=lora_request)

        return LLMInputs(prompt_token_ids=token_ids,
                         prompt=prompt_dict.get("prompt"),
                         multi_modal_data=prompt_dict.get("multi_modal_data"))

    async def add_request_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        """Validate, preprocess and enqueue a request in the engine.

        Raises:
            ValueError: If a LoRA request is given but LoRA is not enabled.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        # Stamp the arrival time before awaiting input processing.
        arrival = time.time() if arrival_time is None else arrival_time

        processed = await self.process_model_inputs_async(
            request_id=request_id, inputs=inputs, lora_request=lora_request)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed,
            params=params,
            arrival_time=arrival,
            lora_request=lora_request,
        )

    async def check_health_async(self) -> None:
        """Delegate the health check to the model executor."""
        self.model_executor.check_health()
308

309

310
class AsyncLLMEngine:
311
    """An asynchronous wrapper for :class:`LLMEngine`.
312

313
314
315
316
317
    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.
318
319
320
321
322

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
Zhuohan Li's avatar
Zhuohan Li committed
323
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
324
325
            async frontend will be executed in a separate process as the
            model workers.
326
        log_requests: Whether to log the requests.
zspo's avatar
zspo committed
327
328
        max_log_len: Maximum number of prompt characters or prompt ID numbers
            being printed in log.
329
330
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
331
332
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
333
    """
334

Antoni Baum's avatar
Antoni Baum committed
335
336
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

337
338
339
340
341
    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 max_log_len: Optional[int] = None,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        # Front-end settings. The two Ray flags must be in place before
        # _init_engine() runs, since it reads them to choose how to build the
        # underlying engine.
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
        self.max_log_len = max_log_len

        self.engine = self._init_engine(*args, **kwargs)

        # Background-loop state; filled in by start_background_loop().
        self.start_engine_loop = start_engine_loop
        self.background_loop: Optional[asyncio.Future] = None
        # Hold the unshielded task too, so it is not garbage collected while
        # only the shielded future is referenced elsewhere.
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        # The exception that put the engine into the errored state, if any.
        self._errored_with: Optional[BaseException] = None

        # Created lazily in start_background_loop() so it binds to the event
        # loop that actually runs the engine.
        self._request_tracker: RequestTracker

362
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        The executor class is chosen from the configured device type and,
        where the branch checks it, the distributed executor backend. Each
        executor module is imported only inside the branch that selects it.
        """
        engine_config = engine_args.create_engine_config()
        backend = engine_config.parallel_config.distributed_executor_backend
        device = engine_config.device_config.device_type

        if device == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            executor_class = NeuronExecutorAsync
        elif device == "tpu":
            from vllm.executor.tpu_executor import TPUExecutorAsync
            executor_class = TPUExecutorAsync
        elif device == "cpu":
            assert backend is None, (
                "Distributed execution is not supported with the CPU backend.")
            from vllm.executor.cpu_executor import CPUExecutorAsync
            executor_class = CPUExecutorAsync
        elif device == "xpu":
            if backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                executor_class = XPUExecutorAsync
            elif backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                executor_class = RayXPUExecutorAsync
            else:
                raise RuntimeError(
                    "Not supported distributed execution model on XPU device.")
        elif backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            executor_class = RayGPUExecutorAsync
        elif backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            executor_class = MultiprocessingGPUExecutorAsync
        else:
            from vllm.executor.gpu_executor import GPUExecutorAsync
            executor_class = GPUExecutorAsync

        # The two leading positionals are worker_use_ray and engine_use_ray.
        return cls(
            backend == "ray",
            engine_args.engine_use_ray,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            max_log_len=engine_args.max_log_len,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
        )

422
423
    @property
    def is_running(self) -> bool:
        """True while the background loop exists and its task is not done."""
        if self.background_loop is None:
            return False
        task = self._background_loop_unshielded
        return task is not None and not task.done()

    @property
    def is_stopped(self) -> bool:
        """True once the engine has errored or its background task is done."""
        if self.errored:
            return True
        task = self._background_loop_unshielded
        return (self.background_loop is not None and task is not None
                and task.done())

    @property
    def errored(self) -> bool:
        """Whether an error has been recorded via :meth:`set_errored`."""
        recorded = self._errored_with
        return recorded is not None

    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` as the reason the engine can no longer be started."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine errored, then fail every tracked request with exc."""
        self.set_errored(exc)
        tracker = self._request_tracker
        tracker.propagate_exception(exc)
444

445
446
    async def get_tokenizer(self) -> "PreTrainedTokenizer":
        """Return the engine's tokenizer, fetching it from the Ray actor when
        the engine runs remotely."""
        if not self.engine_use_ray:
            return self.engine.get_tokenizer()
        return await self.engine.get_tokenizer.remote()  # type: ignore
450

451
    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # Built here rather than in __init__ so its asyncio primitives belong
        # to the event loop that will run the engine.
        self._request_tracker = RequestTracker()

        event_loop = asyncio.get_event_loop()
        task = event_loop.create_task(self.run_engine_loop())
        self._background_loop_unshielded = task
        on_done = partial(_log_task_completion,
                          error_callback=self._error_callback)
        task.add_done_callback(on_done)
        self.background_loop = asyncio.shield(task)
Antoni Baum's avatar
Antoni Baum committed
466
467
468

    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
        """Build the engine in-process, or as a Ray actor if engine_use_ray."""
        if not self.engine_use_ray:
            return self._engine_class(*args, **kwargs)

        if self.worker_use_ray:
            # Reserve no CPUs for the engine actor itself.
            actor_options = {"num_cpus": 0}
        else:
            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
            # order of the arguments.
            cache_config = kwargs["cache_config"]
            parallel_config = kwargs["parallel_config"]
            gpu_share = (cache_config.gpu_memory_utilization
                         if parallel_config.tensor_parallel_size == 1 else 1)
            actor_options = {"num_gpus": gpu_share}

        remote_engine_cls = ray.remote(**actor_options)(self._engine_class)
        return remote_engine_cls.remote(*args, **kwargs)

486
487
488
489
    async def engine_step(self) -> bool:
        """Kick the engine to process the waiting requests.

        New requests are admitted, pending aborts are forwarded, one engine
        step is run and its outputs are routed to their streams.

        Returns True if there are in-progress requests."""
        new_requests, aborted_ids = (
            self._request_tracker.get_new_and_finished_requests())

        for request_kwargs in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            # TODO: Maybe add add_request_batch to reduce Ray overhead
            if self.engine_use_ray:
                add_call = self.engine.add_request.remote  # type: ignore
            else:
                add_call = self.engine.add_request_async
            try:
                await add_call(**request_kwargs)
            except ValueError as err:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    request_kwargs["request_id"],
                    err,
                    verbose=self.log_requests,
                )

        if aborted_ids:
            await self._engine_abort(aborted_ids)

        if self.engine_use_ray:
            step_outputs = await self.engine.step.remote()  # type: ignore
        else:
            step_outputs = await self.engine.step_async()

        # Route every output to the stream of the request it belongs to.
        for output in step_outputs:
            self._request_tracker.process_request_output(
                output, verbose=self.log_requests)

        return len(step_outputs) > 0

Antoni Baum's avatar
Antoni Baum committed
526
527
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Abort ``request_ids`` in the local engine or the Ray actor."""
        if not self.engine_use_ray:
            self.engine.abort_request(request_ids)
            return
        await self.engine.abort_request.remote(request_ids)  # type: ignore

    async def run_engine_loop(self):
        """Drive the engine forever; this coroutine only ends by raising.

        While idle, it waits on the request tracker for new work. Each step is
        bounded by ENGINE_ITERATION_TIMEOUT_S. A timeout marks the engine as
        errored and is re-raised.
        """
        busy = False
        while True:
            if not busy:
                logger.debug("Waiting for new requests...")
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                busy = await asyncio.wait_for(self.engine_step(),
                                              ENGINE_ITERATION_TIMEOUT_S)
            except asyncio.TimeoutError as timeout_err:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(timeout_err)
                raise

            # Give other coroutines a chance to run between iterations.
            await asyncio.sleep(0)

    async def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> AsyncStream:
        """Register a request with the tracker and return its output stream.

        Optionally logs the request, truncating prompt text and token ids to
        ``max_log_len``. Starts the background loop if it is not running and
        ``start_engine_loop`` is set.

        Raises:
            AsyncEngineDeadError: If the loop is not running and may not be
                started automatically.
        """
        if self.log_requests:
            if isinstance(inputs, str):
                prompt_for_log, token_ids_for_log = inputs, None
            else:
                prompt_for_log = inputs.get("prompt")
                token_ids_for_log = inputs.get("prompt_token_ids")

            limit = self.max_log_len
            if limit is not None:
                if prompt_for_log is not None:
                    prompt_for_log = prompt_for_log[:limit]
                if token_ids_for_log is not None:
                    token_ids_for_log = token_ids_for_log[:limit]

            logger.info(
                "Received request %s: prompt: %r, "
                "params: %s, prompt_token_ids: %s, "
                "lora_request: %s.", request_id, prompt_for_log, params,
                token_ids_for_log, lora_request)

        if not self.is_running:
            if not self.start_engine_loop:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")
            self.start_background_loop()

        return self._request_tracker.add_request(
            request_id,
            inputs=inputs,
            params=params,
            arrival_time=time.time() if arrival_time is None else arrival_time,
            lora_request=lora_request,
        )
603

604
    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> AsyncIterator[RequestOutput]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is an async generator. It
        adds the request into the waiting queue of the LLMEngine and streams
        the outputs from the LLMEngine to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is assumed to be the incoming HTTP request.
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in self._process_request(
                request_id,
                inputs,
                sampling_params,
                lora_request=lora_request,
        ):
            # Check that each item is a RequestOutput (not an embedding output).
            yield LLMEngine.validate_output(output, RequestOutput)
679
680
681

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> AsyncIterator[EmbeddingRequestOutput]:
        """Generate outputs for a request from an embedding model.

        Generate outputs for a request. This method is an async generator. It
        adds the request into the waiting queue of the LLMEngine and streams
        the outputs from the LLMEngine to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.

        Yields:
            The output `EmbeddingRequestOutput` objects from the LLMEngine
            for the request. Each output is passed through
            ``LLMEngine.validate_output`` with `EmbeddingRequestOutput` as
            the expected type before it is yielded.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",  # request ids are strings
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` stands for the incoming client request
            >>>     # (see entrypoints/api_server.py).
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in self._process_request(
                request_id,
                inputs,
                pooling_params,
                lora_request=lora_request,
        ):
            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)

    async def _process_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        *,
        lora_request: Optional[LoRARequest] = None,
    ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Common logic to process requests with SamplingParams or
        PoolingParams.

        Registers the request through ``self.add_request`` (stamped with the
        current wall-clock time as its arrival time) and re-yields every
        output produced by the returned stream.

        If consumption stops abnormally, the request is aborted via
        ``self._abort`` and the original exception is re-raised. Abnormal
        stops are an exception, task cancellation, or the caller closing
        this generator before the stream is exhausted (``GeneratorExit``,
        e.g. from ``aclose()``).
        """
        arrival_time = time.time()

        stream = await self.add_request(
            request_id,
            inputs,
            params,
            arrival_time=arrival_time,
            lora_request=lora_request,
        )

        try:
            async for request_output in stream:
                yield request_output
        except (Exception, asyncio.CancelledError, GeneratorExit):
            # CancelledError (Python 3.8+) and GeneratorExit derive from
            # BaseException, so they must be listed explicitly. Without
            # GeneratorExit, a caller that closes this generator early would
            # leave the request running in the engine.
            self._abort(request_id)
            # Bare `raise` re-raises the active exception with its original
            # traceback.
            raise

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if self.is_running:
            # Healthy loop: delegate to the unchecked abort helper.
            return self._abort(request_id)

        raise AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")

    def _abort(self, request_id: str) -> None:
800
801
802
803
804
805
806
807
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
808
809
        self._request_tracker.abort_request(request_id,
                                            verbose=self.log_requests)
810

811
812
813
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine.

        When ``engine_use_ray`` is false the in-process engine is queried
        directly; otherwise the call goes through the engine actor's
        ``.remote()`` method and is awaited.
        """
        if not self.engine_use_ray:
            return self.engine.get_model_config()
        return await self.engine.get_model_config.remote()  # type: ignore

    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine.

        When ``engine_use_ray`` is false the in-process engine is queried
        directly; otherwise the call goes through the engine actor's
        ``.remote()`` method and is awaited.
        """
        if not self.engine_use_ray:
            return self.engine.get_decoding_config()
        return await self.engine.get_decoding_config.remote()  # type: ignore

    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Forward a stats-logging call to the underlying engine.

        Args:
            scheduler_outputs: Optional scheduler outputs, passed through to
                the engine's ``do_log_stats``.
            model_output: Optional sampler outputs, passed through to the
                engine's ``do_log_stats``.
        """
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote(  # type: ignore
                scheduler_outputs, model_output)
        else:
            # Forward the arguments here as well, so the in-process engine
            # receives the same inputs as the Ray path instead of silently
            # dropping them.
            self.engine.do_log_stats(scheduler_outputs, model_output)

    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy.

        Raises:
            AsyncEngineDeadError: If the background loop is stopped.
            RuntimeError: If the Ray engine actor is dead
                (``ray.exceptions.RayActorError`` is chained as the cause).

        Errors raised by the in-process engine's ``check_health_async``
        propagate unchanged. The start of the check and its elapsed time are
        logged at debug level.
        """
        started = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        if not self.engine_use_ray:
            await self.engine.check_health_async()
        else:
            try:
                await self.engine.check_health.remote()  # type: ignore
            except ray.exceptions.RayActorError as err:
                raise RuntimeError("Engine is dead.") from err

        elapsed = time.perf_counter() - started
        logger.debug("Health check took %fs", elapsed)