async_llm_engine.py 32.5 KB
Newer Older
1
2
import asyncio
import time
Antoni Baum's avatar
Antoni Baum committed
3
from functools import partial
4
5
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
                    Set, Tuple, Type, Union)
6

7
8
from transformers import PreTrainedTokenizer

9
import vllm.envs as envs
10
from vllm.config import DecodingConfig, ModelConfig
11
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
12
13
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
14
from vllm.executor.ray_utils import initialize_ray_cluster, ray
15
from vllm.inputs import LLMInputs, PromptInputs
Woosuk Kwon's avatar
Woosuk Kwon committed
16
from vllm.logger import init_logger
17
from vllm.lora.request import LoRARequest
18
19
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
20
from vllm.sampling_params import SamplingParams
21
from vllm.sequence import ExecuteModelRequest, SamplerOutput
yhu422's avatar
yhu422 committed
22
from vllm.usage.usage_lib import UsageContext
23
24

# Logger shared by the classes and helpers in this module.
logger = init_logger(__name__)

# Per-iteration timeout (seconds, as passed to ``asyncio.wait_for``) for the
# background engine loop; the value comes from ``vllm.envs``.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
26

Antoni Baum's avatar
Antoni Baum committed
27

28
29
30
31
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine's background loop has died or is not running,
    so no further requests can be served."""


32
33
34
def _raise_exception_on_finish(
        task: asyncio.Task, error_callback: Callable[[Exception],
                                                     None]) -> None:
    """Done-callback attached to the engine's background loop task.

    The background loop is meant to run forever, so any completion other
    than cancellation is treated as a failure:

    * If the task returned normally, an :class:`AsyncEngineDeadError` is
      created, logged, passed to ``error_callback`` and re-raised (wrapped).
    * If the task raised, that exception is logged, passed to
      ``error_callback`` and re-raised wrapped in
      :class:`AsyncEngineDeadError`.
    * If the task was cancelled, nothing is reported: cancellation is treated
      as an intentional shutdown rather than an engine failure.

    Args:
        task: The finished background-loop task.
        error_callback: Called with the exception that ended the task.

    Raises:
        AsyncEngineDeadError: Whenever the task did not end by cancellation.
    """
    msg = ("Task finished unexpectedly. This should never happen! "
           "Please open an issue on Github.")

    try:
        task.result()
        # NOTE: This will be thrown if task exits normally (which it should
        # not). It is caught by the ``except Exception`` clause below, so the
        # normal-exit case is logged and reported like any other failure.
        raise AsyncEngineDeadError(msg)
    except asyncio.CancelledError:
        # task.result() re-raises CancelledError for a cancelled task. Since
        # Python 3.8 it is a BaseException, so without this clause it would
        # escape this callback and be reported by the event loop as an
        # unhandled error in a callback.
        return
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        raise AsyncEngineDeadError(
            msg + " See stack trace above for the actual cause.") from e
49
50


Antoni Baum's avatar
Antoni Baum committed
51
class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously.

    Producers call :meth:`put` to deliver outputs (or an ``Exception``, which
    is raised in the consumer instead of being yielded) and :meth:`finish` to
    end the stream. Consumers iterate with ``async for``.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._queue: asyncio.Queue = asyncio.Queue()
        # True once finish() has been called; later put() calls are dropped.
        self._finished = False
        # True once the consumer has received the end-of-stream sentinel.
        # Later __anext__ calls must keep raising StopAsyncIteration instead
        # of awaiting a queue that will never receive another item.
        self._exhausted = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Enqueue an output or an exception; ignored after :meth:`finish`."""
        if self._finished:
            return
        self._queue.put_nowait(item)

    def finish(self) -> None:
        """Mark the stream as complete.

        Idempotent: the end-of-stream sentinel is enqueued only once, no
        matter how many times this is called.
        """
        if self._finished:
            return
        self._queue.put_nowait(StopAsyncIteration())
        self._finished = True

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has been called."""
        return self._finished

    def __aiter__(self):
        return self

    async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
        """Return the next queued output.

        Raises:
            StopAsyncIteration: When the end-of-stream sentinel is reached,
                and on every call after that.
            Exception: Any exception item that was :meth:`put` on the stream.
        """
        if self._exhausted:
            raise StopAsyncIteration
        result = await self._queue.get()
        if isinstance(result, StopAsyncIteration):
            self._exhausted = True
            raise result
        if isinstance(result, Exception):
            raise result
        return result


84
85
86
87
88
89
90
91
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    Requests enter through :meth:`add_request`, which queues them in
    ``_new_requests``. They are moved into ``_request_streams`` by
    :meth:`get_new_and_finished_requests`, which the background loop calls
    each iteration. Ids queued by :meth:`abort_request` are removed from
    ``_request_streams`` on that same call.
    """

    def __init__(self) -> None:
        self._request_streams: Dict[str, AsyncStream] = {}
        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set by add_request() so an idle background loop wakes up.
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None).

        When broadcasting to all requests, requests that were added but not
        yet picked up by :meth:`get_new_and_finished_requests` are failed as
        well. Otherwise their consumers would wait forever once the
        background loop has stopped.
        """
        if request_id is not None:
            self._request_streams[request_id].put(exc)
            self.abort_request(request_id)
        else:
            for rid, stream in self._request_streams.items():
                stream.put(exc)
                self.abort_request(rid)
            # Pending requests only enter _request_streams when the loop
            # calls get_new_and_finished_requests(), so the loop above misses
            # them; deliver the error and close their streams directly.
            while not self._new_requests.empty():
                stream, _ = self._new_requests.get_nowait()
                stream.put(exc)
                stream.finish()

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine."""
        request_id = request_output.request_id

        self._request_streams[request_id].put(request_output)
        if request_output.finished:
            if verbose:
                logger.info("Finished request %s.", request_id)
            self.abort_request(request_id)

    def process_exception(self,
                          request_id: str,
                          exception: Exception,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine."""
        self._request_streams[request_id].put(exception)
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id)

    def add_request(self, request_id: str,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Raises:
            KeyError: If ``request_id`` is already being tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        self.new_requests_event.set()

        return stream

    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
        """Abort a request during next background loop iteration."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._finished_requests.put_nowait(request_id)

        if request_id not in self._request_streams or self._request_streams[
                request_id].finished:
            # The request has already finished or been aborted.
            return

        self._request_streams[request_id].finish()

    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine."""
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        # Drain finished ids first so that a request aborted before the loop
        # picked it up is closed here instead of being sent to the engine.
        while not self._finished_requests.empty():
            request_id = self._finished_requests.get_nowait()
            finished_requests.add(request_id)
            self._request_streams.pop(request_id, None)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            if stream.request_id in finished_requests:
                # The request has already been aborted.
                stream.finish()
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Block until at least one new request has been added."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Whether requests are queued for the next loop iteration."""
        return not self._new_requests.empty()
198

Antoni Baum's avatar
Antoni Baum committed
199
200
201
202

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    async def step_async(
            self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Run one scheduling + model-execution iteration asynchronously.

        The scheduler picks the sequence groups for this step together with
        the blocks to swap in, swap out and copy. If anything was scheduled,
        the model runs via ``model_executor.execute_model_async``. The raw
        output is then turned into request outputs, and stats are logged.
        When the step yields no request outputs, the remote workers'
        execution loop is stopped.

        Returns:
            The newly generated request outputs (possibly empty).
        """
        metadata_list, sched_out = self.scheduler.schedule()

        if sched_out.is_empty():
            model_output = []
        else:
            # Run the model on the scheduled sequence groups.
            exec_req = ExecuteModelRequest(
                seq_group_metadata_list=metadata_list,
                blocks_to_swap_in=sched_out.blocks_to_swap_in,
                blocks_to_swap_out=sched_out.blocks_to_swap_out,
                blocks_to_copy=sched_out.blocks_to_copy,
                num_lookahead_slots=sched_out.num_lookahead_slots,
                running_queue_size=sched_out.running_queue_size,
            )
            model_output = await self.model_executor.execute_model_async(
                exec_req)

        outputs = self._process_model_outputs(
            model_output, sched_out.scheduled_seq_groups,
            sched_out.ignored_seq_groups, metadata_list)

        self.do_log_stats(sched_out, model_output)

        if not outputs:
            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            await self.model_executor.stop_remote_worker_execution_loop_async()

        return outputs

    async def process_model_inputs_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        lora_request: Optional[LoRARequest] = None,
    ) -> LLMInputs:
        """Normalize ``inputs`` into :class:`LLMInputs`.

        A bare string is treated as ``{"prompt": inputs}``. If no
        ``prompt_token_ids`` are supplied, the prompt is tokenized
        asynchronously with the tokenizer group (honouring ``lora_request``).
        """
        prompt_dict = {"prompt": inputs} if isinstance(inputs, str) else inputs

        if "prompt_token_ids" in prompt_dict:
            token_ids = prompt_dict["prompt_token_ids"]
        else:
            tokenizer = self.get_tokenizer_group("prompts must be None if "
                                                 "skip_tokenizer_init is True")
            token_ids = await tokenizer.encode_async(
                request_id=request_id,
                prompt=prompt_dict["prompt"],
                lora_request=lora_request)

        return LLMInputs(prompt_token_ids=token_ids,
                         prompt=prompt_dict.get("prompt"),
                         multi_modal_data=prompt_dict.get("multi_modal_data"))

    async def add_request_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        """Preprocess ``inputs`` and add the request to the engine.

        Raises:
            ValueError: If a LoRA request is given but LoRA is not enabled.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        # Stamp the arrival time before (possibly slow) tokenization.
        arrival_time = time.time() if arrival_time is None else arrival_time

        llm_inputs = await self.process_model_inputs_async(
            request_id=request_id, inputs=inputs, lora_request=lora_request)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=llm_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
        )

    async def check_health_async(self) -> None:
        """Run the model executor's health check (a synchronous call)."""
        self.model_executor.check_health()
299

300

301
class AsyncLLMEngine:
302
    """An asynchronous wrapper for :class:`LLMEngine`.
303

304
305
306
307
308
    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.
309

310
    NOTE: For the comprehensive list of arguments, see :class:`LLMEngine`.
311
312
313
314
315

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
Zhuohan Li's avatar
Zhuohan Li committed
316
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
317
318
            async frontend will be executed in a separate process as the
            model workers.
319
        log_requests: Whether to log the requests.
zspo's avatar
zspo committed
320
321
        max_log_len: Maximum number of prompt characters or prompt ID numbers
            being printed in log.
322
323
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
324
325
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
326
    """
327

Antoni Baum's avatar
Antoni Baum committed
328
329
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

330
331
332
333
334
    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 max_log_len: Optional[int] = None,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
        self.max_log_len = max_log_len

        # Must come after the *_use_ray flags: _init_engine reads them to
        # decide whether (and how) to wrap the engine class as a Ray actor.
        self.engine = self._init_engine(*args, **kwargs)

        # Background loop state. ``background_loop`` is the shielded future;
        # ``_background_loop_unshielded`` is the underlying task, kept
        # referenced so it is not garbage collected.
        self.background_loop: Optional[asyncio.Future] = None
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        # Exception that made the engine unusable, if any.
        self._errored_with: Optional[BaseException] = None

        # Lazily created by start_background_loop() so its asyncio
        # primitives belong to the right event loop.
        self._request_tracker: RequestTracker

355
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        The executor class is picked by device type ("neuron", "cpu") and,
        for other devices, by ``distributed_executor_backend`` ("ray", "mp",
        or the default single-process GPU executor).
        """
        engine_config = engine_args.create_engine_config()
        device_type = engine_config.device_config.device_type
        backend = engine_config.parallel_config.distributed_executor_backend

        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            executor_class = NeuronExecutorAsync
        elif device_type == "cpu":
            assert backend is None, (
                "Distributed execution is not supported with the CPU backend.")
            from vllm.executor.cpu_executor import CPUExecutorAsync
            executor_class = CPUExecutorAsync
        elif backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            executor_class = RayGPUExecutorAsync
        elif backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            executor_class = MultiprocessingGPUExecutorAsync
        else:
            from vllm.executor.gpu_executor import GPUExecutorAsync
            executor_class = GPUExecutorAsync

        # Positional args: worker_use_ray, engine_use_ray.
        return cls(
            backend == "ray",
            engine_args.engine_use_ray,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            max_log_len=engine_args.max_log_len,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
        )

401
402
    @property
    def is_running(self) -> bool:
        """True while the background loop task exists and has not finished."""
        task = self._background_loop_unshielded
        if self.background_loop is None or task is None:
            return False
        return not task.done()

    @property
    def is_stopped(self) -> bool:
        """True if the engine has errored or its background loop task has
        finished."""
        if self.errored:
            return True
        task = self._background_loop_unshielded
        return (self.background_loop is not None and task is not None
                and task.done())

    @property
    def errored(self) -> bool:
        """Whether an error has been recorded via :meth:`set_errored`."""
        recorded = self._errored_with
        return recorded is not None

    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` as fatal; afterwards :attr:`errored` is True and
        :meth:`start_background_loop` refuses to start."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine errored and deliver ``exc`` to all tracked
        request streams."""
        self.set_errored(exc)
        tracker = self._request_tracker
        tracker.propagate_exception(exc)
423

424
425
    async def get_tokenizer(self) -> "PreTrainedTokenizer":
        """Return the engine's tokenizer, fetching it from the Ray actor when
        ``engine_use_ray`` is set."""
        if not self.engine_use_ray:
            return self.engine.get_tokenizer()
        return await self.engine.get_tokenizer.remote()  # type: ignore
429

430
    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        task = asyncio.get_event_loop().create_task(self.run_engine_loop())
        self._background_loop_unshielded = task
        on_done = partial(_raise_exception_on_finish,
                          error_callback=self._error_callback)
        task.add_done_callback(on_done)
        self.background_loop = asyncio.shield(task)
Antoni Baum's avatar
Antoni Baum committed
446
447
448

    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
        """Instantiate the engine in-process or as a Ray actor.

        Without ``engine_use_ray`` the engine class is constructed directly.
        Otherwise it is wrapped with ``ray.remote``:

        * with ``num_cpus=0`` when ``worker_use_ray`` is set;
        * otherwise with a GPU request of
          ``cache_config.gpu_memory_utilization`` for tensor-parallel size 1,
          or one full GPU for larger sizes.
        """
        if not self.engine_use_ray:
            return self._engine_class(*args, **kwargs)

        if self.worker_use_ray:
            actor_options = {"num_cpus": 0}
        else:
            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
            # order of the arguments.
            cache_config = kwargs["cache_config"]
            parallel_config = kwargs["parallel_config"]
            if parallel_config.tensor_parallel_size == 1:
                num_gpus = cache_config.gpu_memory_utilization
            else:
                num_gpus = 1
            actor_options = {"num_gpus": num_gpus}

        remote_engine_cls = ray.remote(**actor_options)(self._engine_class)
        return remote_engine_cls.remote(*args, **kwargs)

466
467
468
469
    async def engine_step(self) -> bool:
        """Kick the engine to process the waiting requests.

        One step:

        1. Hands newly added requests to the engine. A ``ValueError`` while
           adding fails only that request's stream.
        2. Aborts finished/aborted requests in the engine.
        3. Runs one engine step and routes each output to its stream.

        Returns True if there are in-progress requests."""
        new_requests, finished_requests = (
            self._request_tracker.get_new_and_finished_requests())

        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            # TODO: Maybe add add_request_batch to reduce Ray overhead
            if self.engine_use_ray:
                submit = self.engine.add_request.remote  # type: ignore
            else:
                submit = self.engine.add_request_async
            try:
                await submit(**new_request)
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    new_request["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if finished_requests:
            await self._engine_abort(finished_requests)

        if self.engine_use_ray:
            run_step = self.engine.step.remote  # type: ignore
        else:
            run_step = self.engine.step_async
        request_outputs = await run_step()

        # Put the outputs into the corresponding streams.
        for out in request_outputs:
            self._request_tracker.process_request_output(
                out, verbose=self.log_requests)

        return bool(request_outputs)

Antoni Baum's avatar
Antoni Baum committed
506
507
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Abort ``request_ids`` in the underlying engine (local or Ray)."""
        if not self.engine_use_ray:
            self.engine.abort_request(request_ids)
            return
        await self.engine.abort_request.remote(request_ids)  # type: ignore

    async def run_engine_loop(self):
        """Drive the engine indefinitely.

        While nothing is in progress, wait for new requests. Then run
        :meth:`engine_step` under ``ENGINE_ITERATION_TIMEOUT_S``. A timeout
        marks the engine errored and is re-raised, ending the loop.
        """
        busy = False
        while True:
            if not busy:
                logger.debug("Waiting for new requests...")
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                busy = await asyncio.wait_for(self.engine_step(),
                                              ENGINE_ITERATION_TIMEOUT_S)
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise
            else:
                # Yield control to the event loop between iterations.
                await asyncio.sleep(0)

    async def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
    ) -> AsyncStream:
        """Preprocess a request and register it with the request tracker.

        Optionally logs the request (prompt and token ids truncated to
        ``max_log_len``). Starts the background loop if it is not running and
        ``start_engine_loop`` is set.

        Returns:
            The :class:`AsyncStream` that will receive the request's outputs.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                cannot be (re)started.
        """
        if self.log_requests:
            if isinstance(inputs, str):
                prompt_for_log, token_ids_for_log = inputs, None
            else:
                prompt_for_log = inputs.get("prompt")
                token_ids_for_log = inputs.get("prompt_token_ids")

            limit = self.max_log_len
            if limit is not None:
                if prompt_for_log is not None:
                    prompt_for_log = prompt_for_log[:limit]
                if token_ids_for_log is not None:
                    token_ids_for_log = token_ids_for_log[:limit]

            logger.info(
                "Received request %s: prompt: %r, "
                "params: %s, prompt_token_ids: %s, "
                "lora_request: %s.", request_id, prompt_for_log, params,
                token_ids_for_log, lora_request)

        if not self.is_running:
            if not self.start_engine_loop:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")
            self.start_background_loop()

        if arrival_time is None:
            arrival_time = time.time()

        if self.engine_use_ray:
            preprocess = (
                self.engine.process_model_inputs_async.remote)  # type: ignore
        else:
            preprocess = self.engine.process_model_inputs_async
        processed_inputs = await preprocess(request_id=request_id,
                                            inputs=inputs,
                                            lora_request=lora_request)

        return self._request_tracker.add_request(
            request_id,
            inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
        )
595

596
    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> AsyncIterator[RequestOutput]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is an async generator. It
        adds the request into the waiting queue of the LLMEngine and streams
        the outputs from the LLMEngine to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming client (HTTP) request object
            >>>     # provided by the web framework in api_server.py.
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in self._process_request(
                request_id,
                inputs,
                sampling_params,
                lora_request=lora_request,
        ):
            # Narrow the (RequestOutput | EmbeddingRequestOutput) stream to
            # RequestOutput for this generation API.
            yield LLMEngine.validate_output(output, RequestOutput)
671
672
673

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> AsyncIterator[EmbeddingRequestOutput]:
        """Stream the outputs of a request to an embedding model.

        This coroutine hands the request to :meth:`_process_request` and
        yields every output it produces. Each output is passed through
        ``LLMEngine.validate_output`` to check that it is an
        `EmbeddingRequestOutput` before it reaches the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.

        Yields:
            The output `EmbeddingRequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming HTTP request object of the
            >>>     # API server handling this call.
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        output_stream = self._process_request(
            request_id,
            inputs,
            pooling_params,
            lora_request=lora_request,
        )
        async for out in output_stream:
            yield LLMEngine.validate_output(out, EmbeddingRequestOutput)
745

746
    async def _process_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        *,
        lora_request: Optional[LoRARequest] = None,
    ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Common logic to process requests with SamplingParams or
        PoolingParams.

        Registers the request through :meth:`add_request`, using the current
        wall-clock time (``time.time()``) as its arrival time. It then
        re-yields every output from the stream that ``add_request`` returns.

        If streaming ends abnormally, the request is aborted before the error
        propagates, so the engine does not keep working on a request that
        nobody is consuming. "Abnormally" means any of:

        * an ``Exception`` raised while reading the stream;
        * ``asyncio.CancelledError``, e.g. when the consuming task is
          cancelled;
        * ``GeneratorExit``, raised at the ``yield`` when the consumer stops
          iterating early and this async generator is closed.

        Args:
            request_id: The unique id of the request.
            inputs: The inputs to the LLM.
            params: Sampling parameters (generation) or pooling parameters
                (embedding) for the request.
            lora_request: LoRA request to use, if any.

        Yields:
            The request outputs, in the order the stream produces them.
        """
        arrival_time = time.time()

        stream = await self.add_request(
            request_id,
            inputs,
            params,
            arrival_time=arrival_time,
            lora_request=lora_request,
        )

        try:
            async for request_output in stream:
                yield request_output
        except (Exception, asyncio.CancelledError, GeneratorExit):
            # GeneratorExit is a BaseException, so it must be listed
            # explicitly; without it an early-closed consumer would leave the
            # request running. `_abort` is a no-op for a request that has
            # already finished, so aborting here is safe in every case.
            self._abort(request_id)
            # Bare `raise` re-raises the active exception with its original
            # traceback intact.
            raise
772

Antoni Baum's avatar
Antoni Baum committed
773
774
    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if self.is_running:
            return self._abort(request_id)

        raise AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")
790

Antoni Baum's avatar
Antoni Baum committed
791
    def _abort(self, request_id: str) -> None:
792
793
794
795
796
797
798
799
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
800
801
        self._request_tracker.abort_request(request_id,
                                            verbose=self.log_requests)
802

803
804
805
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine.

        With ``engine_use_ray`` set, the config is fetched through a remote
        call on the engine and awaited. Otherwise it is read directly from
        the local engine.
        """
        if not self.engine_use_ray:
            return self.engine.get_model_config()
        # Ray-style method invocation: `.remote()` returns an awaitable.
        return await self.engine.get_model_config.remote()  # type: ignore

810
811
812
813
814
815
816
817
    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine.

        With ``engine_use_ray`` set, the config is fetched through a remote
        call on the engine and awaited. Otherwise it is read directly from
        the local engine.
        """
        if not self.engine_use_ray:
            return self.engine.get_decoding_config()
        # Ray-style method invocation: `.remote()` returns an awaitable.
        return await self.engine.get_decoding_config.remote(  # type: ignore
        )

818
819
820
821
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the underlying engine to log its statistics.

        Args:
            scheduler_outputs: Optional scheduler outputs, forwarded to the
                engine's ``do_log_stats``.
            model_output: Optional model outputs, forwarded to the engine's
                ``do_log_stats``.
        """
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote(  # type: ignore
                scheduler_outputs, model_output)
        else:
            # Forward the arguments on the local path as well, so both paths
            # log from the same inputs (the Ray call above shows the engine's
            # `do_log_stats` accepts them).
            self.engine.do_log_stats(scheduler_outputs, model_output)
827

828
    async def check_health(self) -> None:
        """Raise an error if the engine is unhealthy.

        The duration of the check is logged at debug level.

        Raises:
            AsyncEngineDeadError: If the background loop is stopped.
            RuntimeError: If ``engine_use_ray`` is set and the remote health
                check fails with ``ray.exceptions.RayActorError``.

        Errors raised by the local engine's ``check_health_async``
        propagate unchanged.
        """
        started_at = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        if not self.engine_use_ray:
            await self.engine.check_health_async()
        else:
            try:
                await self.engine.check_health.remote()  # type: ignore
            except ray.exceptions.RayActorError as e:
                raise RuntimeError("Engine is dead.") from e

        elapsed = time.perf_counter() - started_at
        logger.debug("Health check took %fs", elapsed)