async_llm_engine.py 34.1 KB
Newer Older
1
2
import asyncio
import time
Antoni Baum's avatar
Antoni Baum committed
3
from functools import partial
4
5
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
                    Set, Tuple, Type, Union)
6

7
8
from transformers import PreTrainedTokenizer

9
import vllm.envs as envs
10
from vllm.config import DecodingConfig, ModelConfig
11
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
12
from vllm.engine.arg_utils import AsyncEngineArgs
13
from vllm.engine.async_timeout import asyncio_timeout
Woosuk Kwon's avatar
Woosuk Kwon committed
14
from vllm.engine.llm_engine import LLMEngine
15
from vllm.executor.ray_utils import initialize_ray_cluster, ray
16
from vllm.inputs import LLMInputs, PromptInputs
Woosuk Kwon's avatar
Woosuk Kwon committed
17
from vllm.logger import init_logger
18
from vllm.lora.request import LoRARequest
19
20
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
21
from vllm.sampling_params import SamplingParams
22
from vllm.sequence import ExecuteModelRequest, SamplerOutput
yhu422's avatar
yhu422 committed
23
from vllm.usage.usage_lib import UsageContext
24
25

logger = init_logger(__name__)
26
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
27

Antoni Baum's avatar
Antoni Baum committed
28

29
30
31
32
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine's background loop has stopped (or errored) and
    the engine can no longer serve requests."""


33
34
35
36
37
38
39
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback for the `engine.run_engine_loop()` background task.

    That task runs a `while True` loop, so it can only end by being cancelled
    or by raising. This callback reports how it ended:

    - cancelled: treated as a graceful shutdown and logged at INFO level;
    - raised: the error is logged, passed to ``error_callback`` and re-raised
      wrapped in :class:`AsyncEngineDeadError`;
    - returned normally: impossible by construction, so an ``AssertionError``
      is raised, which the generic handler below then reports exactly like
      any other failure (logged, sent to ``error_callback``, wrapped).

    Args:
        task: The finished background-loop task.
        error_callback: Invoked with the exception that ended the task.

    Raises:
        AsyncEngineDeadError: If the task did not end by cancellation.
    """
    try:
        return_value = task.result()
        # Caught by the `except Exception` clause below on purpose.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {return_value}")
    except asyncio.exceptions.CancelledError:
        # We assume that if the task is cancelled, we are gracefully shutting
        # down. This should only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from e
59
60


Antoni Baum's avatar
Antoni Baum committed
61
class AsyncStream:
    """Asynchronous iterator over the outputs of one request.

    Yields RequestOutput or EmbeddingRequestOutput objects. If an Exception
    is put on the stream, it is raised to the consumer instead of yielded.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Enqueue ``item`` for the consumer; silently ignored once the
        stream has been finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(self) -> None:
        """Mark the stream as finished.

        A StopAsyncIteration sentinel is enqueued; since it is an Exception,
        ``__anext__`` raises it once everything queued before it has been
        consumed, which ends the consumer's ``async for`` loop.
        """
        self._queue.put_nowait(StopAsyncIteration())
        self._finished = True

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has been called."""
        return self._finished

    def __aiter__(self):
        return self

    async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
        item = await self._queue.get()
        if not isinstance(item, Exception):
            return item
        raise item


94
95
96
97
98
99
100
101
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    Additions and aborts are buffered in queues and handed to the engine in
    batches by :meth:`get_new_and_finished_requests`, which the background
    loop calls at the start of every iteration.
    """

    def __init__(self) -> None:
        # request_id -> stream, for requests already handed to the engine.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Ids finished or aborted since the last background-loop iteration.
        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, engine add_request kwargs) pairs not yet registered.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set by add_request so that an idle background loop wakes up.
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None).

        Each affected stream receives ``exc`` and the request is then aborted.
        """
        if request_id is not None:
            self._request_streams[request_id].put(exc)
            self.abort_request(request_id)
        else:
            for rid, stream in self._request_streams.items():
                stream.put(exc)
                self.abort_request(rid)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        The output is forwarded to its stream; a finished output also aborts
        (i.e. finishes and unregisters) the request.
        """
        request_id = request_output.request_id

        self._request_streams[request_id].put(request_output)
        if request_output.finished:
            if verbose:
                logger.info("Finished request %s.", request_id)
            self.abort_request(request_id)

    def process_exception(self,
                          request_id: str,
                          exception: Exception,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine to the request's stream
        and abort the request."""
        self._request_streams[request_id].put(exception)
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id)

    def add_request(self, request_id: str,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Raises:
            KeyError: If a request with this id is already registered. A
                duplicate that is still pending registration is rejected
                later, in :meth:`get_new_and_finished_requests`, by raising
                KeyError through its stream.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        self.new_requests_event.set()

        return stream

    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
        """Abort a request during next background loop iteration.

        The id is always queued for the engine; the stream, if registered and
        not yet finished, is finished immediately.
        """
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._finished_requests.put_nowait(request_id)

        if request_id not in self._request_streams or self._request_streams[
                request_id].finished:
            # The request has already finished or been aborted.
            return

        self._request_streams[request_id].finish()

    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        Finished/aborted ids are unregistered first, so an id can be reused
        once its previous request has been collected here.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._finished_requests.empty():
            request_id = self._finished_requests.get_nowait()
            finished_requests.add(request_id)
            self._request_streams.pop(request_id, None)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            request_id = stream.request_id
            if request_id in finished_requests:
                # The request has already been aborted.
                stream.finish()
                continue
            if request_id in self._request_streams:
                # A request with the same id was queued before this one was
                # registered (add_request only checks registered ids).
                # Registering it would overwrite the live stream and leave its
                # consumer waiting forever, so reject the duplicate instead.
                stream.put(KeyError(f"Request {request_id} already exists."))
                stream.finish()
                continue
            self._request_streams[request_id] = stream
            new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Block until at least one new request has been queued."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Whether any request is waiting to be registered."""
        return not self._new_requests.empty()
208

Antoni Baum's avatar
Antoni Baum committed
209
210
211
212

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    async def step_async(
            self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Run one decoding iteration and return the newly generated results.

        The workers are run asynchronously where possible. The scheduler
        first picks the sequences for this iteration, together with the token
        blocks to swap in/out/copy. The model is then executed on that batch
        (skipped when nothing was scheduled), and its outputs are processed
        into request outputs, which are returned.
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

        output = []
        if not scheduler_outputs.is_empty():
            # Execute the model on the scheduled batch.
            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
            )
            output = await self.model_executor.execute_model_async(
                execute_model_req)

        request_outputs = self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

        # Stats and tracing are recorded on every step.
        self.do_log_stats(scheduler_outputs, output)
        self.do_tracing(scheduler_outputs)

        if not request_outputs:
            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            await self.model_executor.stop_remote_worker_execution_loop_async()

        return request_outputs

    async def process_model_inputs_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        lora_request: Optional[LoRARequest] = None,
    ) -> LLMInputs:
        """Normalize ``inputs`` into :class:`LLMInputs`.

        A bare string is treated as a prompt. If no ``prompt_token_ids`` are
        given, the prompt is tokenized asynchronously.
        """
        if isinstance(inputs, str):
            inputs = {"prompt": inputs}

        if "prompt_token_ids" in inputs:
            prompt_token_ids = inputs["prompt_token_ids"]
        else:
            tokenizer = self.get_tokenizer_group("prompts must be None if "
                                                 "skip_tokenizer_init is True")
            prompt_token_ids = await tokenizer.encode_async(
                request_id=request_id,
                prompt=inputs["prompt"],
                lora_request=lora_request)

        return LLMInputs(prompt_token_ids=prompt_token_ids,
                         prompt=inputs.get("prompt"),
                         multi_modal_data=inputs.get("multi_modal_data"))

    async def add_request_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> None:
        """Validate and preprocess a request, then hand it to
        ``_add_processed_request``.

        Raises:
            ValueError: If a LoRA request is given but LoRA is not enabled.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        arrival_time = time.time() if arrival_time is None else arrival_time

        processed_inputs = await self.process_model_inputs_async(
            request_id=request_id, inputs=inputs, lora_request=lora_request)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )

    async def check_health_async(self) -> None:
        """Run the tokenizer (if any) and model-executor health checks."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
316

317

318
class AsyncLLMEngine:
319
    """An asynchronous wrapper for :class:`LLMEngine`.
320

321
322
323
324
325
    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.
326
327
328
329
330

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
Zhuohan Li's avatar
Zhuohan Li committed
331
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
332
333
            async frontend will be executed in a separate process as the
            model workers.
334
        log_requests: Whether to log the requests.
zspo's avatar
zspo committed
335
336
        max_log_len: Maximum number of prompt characters or prompt ID numbers
            being printed in log.
337
338
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
339
340
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
341
    """
342

Antoni Baum's avatar
Antoni Baum committed
343
344
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

345
346
347
348
349
    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 max_log_len: Optional[int] = None,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        # Deployment / logging options, needed by _init_engine below.
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
        self.max_log_len = max_log_len
        # The wrapped engine: a local instance or a Ray actor handle.
        self.engine = self._init_engine(*args, **kwargs)

        # Background-loop state; populated by start_background_loop().
        self.background_loop: Optional[asyncio.Future] = None
        # Also keep a reference to the unshielded task so that it is not
        # garbage collected while it runs.
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        # Exception that killed the background loop, if any.
        self._errored_with: Optional[BaseException] = None

        # Created lazily by start_background_loop() so that it is bound to
        # the event loop that actually runs the engine.
        self._request_tracker: RequestTracker

370
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_config = engine_args.create_engine_config()
        backend = engine_config.parallel_config.distributed_executor_backend
        device_type = engine_config.device_config.device_type

        # Select the executor for the device / distributed backend. Executor
        # modules are imported inside each branch, so only the chosen one is
        # loaded.
        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            executor_class = NeuronExecutorAsync
        elif device_type == "tpu":
            from vllm.executor.tpu_executor import TPUExecutorAsync
            executor_class = TPUExecutorAsync
        elif device_type == "cpu":
            assert backend is None, (
                "Distributed execution is not supported with the CPU backend.")
            from vllm.executor.cpu_executor import CPUExecutorAsync
            executor_class = CPUExecutorAsync
        elif device_type == "xpu":
            if backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                executor_class = XPUExecutorAsync
            elif backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                executor_class = RayXPUExecutorAsync
            else:
                raise RuntimeError(
                    "Not supported distributed execution model on XPU device.")
        elif backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            executor_class = RayGPUExecutorAsync
        elif backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            executor_class = MultiprocessingGPUExecutorAsync
        else:
            from vllm.executor.gpu_executor import GPUExecutorAsync
            executor_class = GPUExecutorAsync

        # Create the async LLM engine; worker_use_ray is derived from the
        # distributed backend.
        return cls(
            backend == "ray",
            engine_args.engine_use_ray,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            max_log_len=engine_args.max_log_len,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
        )

430
431
    @property
    def is_running(self) -> bool:
        """True while the background loop task exists and has not finished."""
        task = self._background_loop_unshielded
        if self.background_loop is None or task is None:
            return False
        return not task.done()

    @property
    def is_stopped(self) -> bool:
        """True if the engine has errored or its background loop finished."""
        if self.errored:
            return True
        task = self._background_loop_unshielded
        return (self.background_loop is not None and task is not None
                and task.done())

    @property
    def errored(self) -> bool:
        """Whether a fatal error has been recorded via :meth:`set_errored`."""
        has_error = self._errored_with is not None
        return has_error

    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` as the error that killed the engine."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine as errored and push ``exc`` to every tracked
        request stream."""
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)
452

453
454
    async def get_tokenizer(self) -> "PreTrainedTokenizer":
        """Return the engine's tokenizer, fetched from the Ray actor when the
        engine runs remotely."""
        if not self.engine_use_ray:
            return self.engine.get_tokenizer()
        return await self.engine.get_tokenizer.remote()  # type: ignore
458

459
    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the background loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        task = asyncio.get_event_loop().create_task(self.run_engine_loop())
        self._background_loop_unshielded = task
        task.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        # Callers only see the shielded future, so cancelling it does not
        # cancel the underlying loop task.
        self.background_loop = asyncio.shield(task)
Antoni Baum's avatar
Antoni Baum committed
474
475
476

    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
        """Instantiate the engine locally, or as a Ray actor when
        ``engine_use_ray`` is set."""
        if not self.engine_use_ray:
            return self._engine_class(*args, **kwargs)

        if self.worker_use_ray:
            actor_options = {"num_cpus": 0}
        else:
            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
            # order of the arguments.
            cache_config = kwargs["cache_config"]
            parallel_config = kwargs["parallel_config"]
            num_gpus = (cache_config.gpu_memory_utilization
                        if parallel_config.tensor_parallel_size == 1 else 1)
            actor_options = {"num_gpus": num_gpus}

        remote_engine_class = ray.remote(**actor_options)(
            self._engine_class).remote
        return remote_engine_class(*args, **kwargs)

494
495
496
497
    async def engine_step(self) -> bool:
        """Kick the engine to process the waiting requests.

        Pending additions are sent to the engine (validation ValueErrors are
        routed to the offending request's stream), pending aborts are
        forwarded, one engine step is run and its outputs are dispatched to
        their streams.

        Returns True if there are in-progress requests."""
        pending, to_abort = (
            self._request_tracker.get_new_and_finished_requests())

        for request_kwargs in pending:
            # Add the request into the vLLM engine's waiting queue.
            # TODO: Maybe add add_request_batch to reduce Ray overhead
            try:
                if self.engine_use_ray:
                    await self.engine.add_request.remote(  # type: ignore
                        **request_kwargs)
                else:
                    await self.engine.add_request_async(**request_kwargs)
            except ValueError as err:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    request_kwargs["request_id"],
                    err,
                    verbose=self.log_requests,
                )

        if to_abort:
            await self._engine_abort(to_abort)

        if self.engine_use_ray:
            step_outputs = await self.engine.step.remote()  # type: ignore
        else:
            step_outputs = await self.engine.step_async()

        # Put the outputs into the corresponding streams.
        for step_output in step_outputs:
            self._request_tracker.process_request_output(
                step_output, verbose=self.log_requests)

        return len(step_outputs) > 0

Antoni Baum's avatar
Antoni Baum committed
534
535
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Ask the (local or remote) engine to abort ``request_ids``."""
        if not self.engine_use_ray:
            self.engine.abort_request(request_ids)
            return
        await self.engine.abort_request.remote(request_ids)  # type: ignore

    async def run_engine_loop(self):
        """Drive the engine forever, idling while there is nothing to do.

        Never returns normally: it ends only when cancelled or when an
        iteration raises. An iteration exceeding ENGINE_ITERATION_TIMEOUT_S
        marks the engine as errored and re-raises the TimeoutError.
        """
        busy = False
        while True:
            if not busy:
                logger.debug("Waiting for new requests...")
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    busy = await self.engine_step()
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise
            # Yield control so other tasks can run between iterations.
            await asyncio.sleep(0)

    async def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> AsyncStream:
        """Queue a request for the background loop and return its stream.

        Starts the background loop if it is not running and
        ``start_engine_loop`` is enabled.

        Raises:
            AsyncEngineDeadError: If the loop is not running and may not be
                (re)started.
        """
        if self.log_requests:
            if isinstance(inputs, str):
                logged_prompt, logged_token_ids = inputs, None
            else:
                logged_prompt = inputs.get("prompt")
                logged_token_ids = inputs.get("prompt_token_ids")

            # Truncate what is logged to at most max_log_len items.
            limit = self.max_log_len
            if limit is not None:
                if logged_prompt is not None:
                    logged_prompt = logged_prompt[:limit]
                if logged_token_ids is not None:
                    logged_token_ids = logged_token_ids[:limit]

            logger.info(
                "Received request %s: prompt: %r, "
                "params: %s, prompt_token_ids: %s, "
                "lora_request: %s.", request_id, logged_prompt, params,
                logged_token_ids, lora_request)

        if not self.is_running:
            if not self.start_engine_loop:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")
            self.start_background_loop()

        return self._request_tracker.add_request(
            request_id,
            inputs=inputs,
            params=params,
            arrival_time=time.time() if arrival_time is None else arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )
613

614
    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> AsyncIterator[RequestOutput]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is an async generator. It
        adds the request into the waiting queue of the LLMEngine and streams
        the outputs from the LLMEngine to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",  # request ids are strings
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` stands for the incoming client (HTTP) request.
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in self._process_request(
                request_id,
                inputs,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
        ):
            yield LLMEngine.validate_output(output, RequestOutput)
692
693
694

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> AsyncIterator[EmbeddingRequestOutput]:
        """Stream the outputs of a request to an embedding model.

        This coroutine hands the request to ``_process_request``, which
        enqueues it with the LLMEngine, and re-yields every output produced
        for it after validating that each one is an
        `EmbeddingRequestOutput`.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.

        Yields:
            The output `EmbeddingRequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the client's HTTP request object.
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        # Calling the async generator only builds it; nothing runs until
        # the first iteration below.
        pending_outputs = self._process_request(
            request_id,
            inputs,
            pooling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )
        async for raw_output in pending_outputs:
            checked = LLMEngine.validate_output(raw_output,
                                                EmbeddingRequestOutput)
            yield checked

    async def _process_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        *,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Common logic to process requests with SamplingParams or
        PoolingParams.

        Registers the request through ``add_request``, stamping the current
        wall-clock time as its arrival time. Then re-yields every item from
        the stream that ``add_request`` returns.

        If iteration ends abnormally, the request is aborted and the
        exception is re-raised unchanged. This covers an exception from the
        stream, ``asyncio.CancelledError``, and ``GeneratorExit`` (the
        consumer closing this generator before it is exhausted).
        """
        arrival_time = time.time()

        stream = await self.add_request(
            request_id,
            inputs,
            params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )

        try:
            async for request_output in stream:
                yield request_output
        except (Exception, asyncio.CancelledError, GeneratorExit):
            # GeneratorExit derives from BaseException, not Exception, so it
            # must be listed explicitly. It is raised at the ``yield`` above
            # when the consumer stops early (e.g. ``break`` + ``aclose()``,
            # or finalization of an abandoned generator). Without this, the
            # request would never be aborted. ``_abort`` is a no-op for
            # requests that already finished.
            self._abort(request_id)
            # Bare ``raise`` keeps the original traceback intact.
            raise

    async def abort(self, request_id: str) -> None:
        """Abort a submitted request.

        Aborting a request that has already finished, or whose id is not
        found, is a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if self.is_running:
            self._abort(request_id)
            return

        raise AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")

    def _abort(self, request_id: str) -> None:
        """Synchronously abort a request via the request tracker.

        A request that has already finished, or whose id is not found, is
        left alone (no-op). Unlike :meth:`abort`, this does not check
        whether the background loop is running.

        Args:
            request_id: The unique id of the request.
        """
        tracker = self._request_tracker
        # Logging verbosity follows the engine's ``log_requests`` setting.
        tracker.abort_request(request_id, verbose=self.log_requests)

    async def get_model_config(self) -> ModelConfig:
        """Return the model configuration of the vLLM engine.

        With a local engine the config is read directly. When the engine
        runs under Ray, it is fetched through a remote call.
        """
        if not self.engine_use_ray:
            return self.engine.get_model_config()
        return await self.engine.get_model_config.remote()  # type: ignore

    async def get_decoding_config(self) -> DecodingConfig:
        """Return the decoding configuration of the vLLM engine.

        With a local engine the config is read directly. When the engine
        runs under Ray, it is fetched through a remote call.
        """
        if not self.engine_use_ray:
            return self.engine.get_decoding_config()
        return await self.engine.get_decoding_config.remote()  # type: ignore

    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the underlying engine to log its stats.

        Args:
            scheduler_outputs: Optional scheduler outputs, passed through to
                the engine's ``do_log_stats``.
            model_output: Optional model outputs, passed through to the
                engine's ``do_log_stats``.
        """
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote(  # type: ignore
                scheduler_outputs, model_output)
        else:
            # Forward the same arguments as the Ray path. Previously they
            # were silently dropped when the engine ran locally.
            self.engine.do_log_stats(scheduler_outputs, model_output)

    async def check_health(self) -> None:
        """Raise an error if the engine is unhealthy.

        Logs the health check duration at debug level on success.

        Raises:
            AsyncEngineDeadError: If the background loop is stopped.
            RuntimeError: If the engine runs under Ray and its actor has
                died (wrapping ``ray.exceptions.RayActorError``).
        """
        started_at = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        if not self.engine_use_ray:
            await self.engine.check_health_async()
        else:
            try:
                await self.engine.check_health.remote()  # type: ignore
            except ray.exceptions.RayActorError as err:
                raise RuntimeError("Engine is dead.") from err

        elapsed = time.perf_counter() - started_at
        logger.debug("Health check took %fs", elapsed)

    async def is_tracing_enabled(self) -> bool:
        """Return the underlying engine's ``is_tracing_enabled()`` result.

        Queried through a remote call when the engine runs under Ray, and
        directly on the local engine otherwise.
        """
        if not self.engine_use_ray:
            return self.engine.is_tracing_enabled()
        return await self.engine.is_tracing_enabled.remote()  # type: ignore