async_llm_engine.py 39.1 KB
Newer Older
1
2
import asyncio
import time
Antoni Baum's avatar
Antoni Baum committed
3
from functools import partial
4
5
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
                    Set, Tuple, Type, Union)
6

7
8
from transformers import PreTrainedTokenizer

9
import vllm.envs as envs
10
from vllm.config import DecodingConfig, ModelConfig
11
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
12
from vllm.engine.arg_utils import AsyncEngineArgs
13
from vllm.engine.async_timeout import asyncio_timeout
Woosuk Kwon's avatar
Woosuk Kwon committed
14
from vllm.engine.llm_engine import LLMEngine
15
from vllm.engine.metrics import StatLoggerBase
16
from vllm.executor.ray_utils import initialize_ray_cluster, ray
17
from vllm.inputs import LLMInputs, PromptInputs
Woosuk Kwon's avatar
Woosuk Kwon committed
18
from vllm.logger import init_logger
19
from vllm.lora.request import LoRARequest
20
21
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
22
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
23
from vllm.sampling_params import SamplingParams
24
from vllm.sequence import ExecuteModelRequest, SamplerOutput
yhu422's avatar
yhu422 committed
25
from vllm.usage.usage_lib import UsageContext
26
27

logger = init_logger(__name__)  # Module-level logger used by every class here.
28
# Timeout applied to each wait in run_engine_loop (seconds, per the _S suffix);
# read from vllm.envs once, when this module is imported.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
29

Antoni Baum's avatar
Antoni Baum committed
30

31
32
33
34
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine's background loop has failed or is not running,
    so the async engine can no longer serve requests."""


35
36
37
38
39
40
41
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback for the ``engine.run_engine_loop()`` task only.

    That task runs a ``while True`` loop that can only exit through an
    exception, so the outcomes are handled as follows:

    * cancelled: assumed to be a graceful shutdown; only logged.
    * returned normally: an ``AssertionError`` is raised and then handled
      exactly like a failure (next bullet).
    * raised: the error is logged, passed to ``error_callback`` and re-raised
      as :class:`AsyncEngineDeadError` chained to the original error.

    Args:
        task: The completed background-loop task.
        error_callback: Invoked with the failure before it is re-raised.

    Raises:
        AsyncEngineDeadError: Whenever the task did not end by cancellation.
    """
    try:
        return_value = task.result()
        # Raised inside the ``try`` so the ``except Exception`` clause below
        # reports a normal return like any other failure.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {return_value}")
    except asyncio.exceptions.CancelledError:
        # We assume that if the task is cancelled, we are gracefully shutting
        # down. This should only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from e
61
62


Antoni Baum's avatar
Antoni Baum committed
63
class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously.

    Items are buffered in an unbounded :class:`asyncio.Queue`. Any
    ``Exception`` taken from the queue is raised to the consumer; this
    includes the ``StopAsyncIteration`` sentinel enqueued by :meth:`finish`,
    which ends ``async for`` loops.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Enqueue ``item``; silently dropped once the stream is finished."""
        if self._finished:
            return
        self._queue.put_nowait(item)

    def finish(self) -> None:
        """Enqueue the end-of-stream sentinel and mark the stream finished.

        Idempotent: only the first call enqueues the sentinel, so repeated
        calls never leave stale sentinels behind in the queue.
        """
        if self._finished:
            return
        self._queue.put_nowait(StopAsyncIteration())
        self._finished = True

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has been called."""
        return self._finished

    def __aiter__(self):
        return self

    async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
        result = await self._queue.get()
        # StopAsyncIteration is an Exception subclass, so the sentinel from
        # finish() is raised here too and terminates iteration.
        if isinstance(result, Exception):
            raise result
        return result


96
97
98
99
100
101
102
103
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    New requests and abort requests are buffered in asyncio queues and handed
    to the background loop in one batch by
    :meth:`get_new_and_finished_requests`. Outputs and errors coming back from
    the engine are routed to the matching per-request :class:`AsyncStream`.
    """

    def __init__(self) -> None:
        # request_id -> stream, for requests moved out of _new_requests by
        # get_new_and_finished_requests() and not yet dropped.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Request ids to abort on the next background loop iteration.
        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, add_request kwargs) pairs not yet handed to the engine.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set by add_request(); awaited by wait_for_new_requests().
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None).

        Every affected request is also aborted. A ``request_id`` whose stream
        is no longer registered is still aborted instead of raising
        ``KeyError``.
        """
        if request_id is not None:
            stream = self._request_streams.get(request_id)
            if stream is not None:
                stream.put(exc)
            self.abort_request(request_id)
        else:
            for rid, stream in self._request_streams.items():
                stream.put(exc)
                self.abort_request(rid)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        The output is delivered to its stream; a finished output also aborts
        (i.e. retires) the request. Outputs for requests whose stream has
        already been dropped are ignored.
        """
        request_id = request_output.request_id
        stream = self._request_streams.get(request_id)
        if stream is None:
            # Guard against a KeyError: the request can be aborted and its
            # stream dropped (by another virtual engine's step draining the
            # abort queue) while this output was being generated.
            return

        stream.put(request_output)
        if request_output.finished:
            if verbose:
                logger.info("Finished request %s.", request_id)
            self.abort_request(request_id)

    def process_exception(self,
                          request_id: str,
                          exception: Exception,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine.

        The exception is delivered to the request's stream (if it is still
        registered) and the request is aborted.
        """
        stream = self._request_streams.get(request_id)
        if stream is not None:
            stream.put(exception)
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id)

    def add_request(self, request_id: str,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Raises:
            KeyError: If a stream for ``request_id`` is already registered.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        # Wake the background loop if it is idling in wait_for_new_requests.
        self.new_requests_event.set()

        return stream

    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
        """Abort a request during next background loop iteration."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        # Always queue the id: get_new_and_finished_requests() uses it both to
        # abort the request in the engine and to cancel a still-pending one.
        self._finished_requests.put_nowait(request_id)

        if request_id not in self._request_streams or self._request_streams[
                request_id].finished:
            # The request has already finished or been aborted.
            return

        self._request_streams[request_id].finish()

    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        Drains both queues. Streams of finished ids are dropped; pending
        requests whose id was aborted are finished instead of being returned.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._finished_requests.empty():
            request_id = self._finished_requests.get_nowait()
            finished_requests.add(request_id)
            self._request_streams.pop(request_id, None)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            if stream.request_id in finished_requests:
                # The request has already been aborted.
                stream.finish()
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Block until at least one new request is queued, then reset the
        wake-up event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        """Whether any request is waiting to be handed to the engine."""
        return not self._new_requests.empty()
210

Antoni Baum's avatar
Antoni Baum committed
211
212
213
214

class _AsyncLLMEngine(LLMEngine):
    """:class:`LLMEngine` subclass exposing coroutine versions of the step,
    input-processing, request-adding and health-check entry points."""

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Run one decoding iteration on ``virtual_engine`` and return the
        newly generated results; the workers are awaited asynchronously.

        The virtual engine's scheduler picks the sequence groups to run and
        the blocks to swap in/out/copy. When something was scheduled, the
        model executor is awaited on that batch; otherwise no model call is
        made. The outputs are then processed, stats logged and tracing
        recorded.
        """
        scheduler = self.scheduler[virtual_engine]
        seq_group_metadata_list, scheduler_outputs = scheduler.schedule()

        output = []
        if not scheduler_outputs.is_empty():
            finished_requests_ids = (
                scheduler.get_and_reset_finished_requests_ids())
            model_request = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids)
            output = await self.model_executor.execute_model_async(
                model_request)

        request_outputs = self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

        self.do_log_stats(scheduler_outputs, output)
        self.do_tracing(scheduler_outputs)
        return request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def process_model_inputs_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        """Turn ``inputs`` into :class:`LLMInputs` for the engine.

        A bare string is treated as ``{"prompt": inputs}``. Missing
        ``prompt_token_ids`` are produced with the async tokenizer. With a
        prompt adapter, one ``0`` placeholder id per virtual token is
        prepended. The result goes through ``self.input_processor``.
        """
        if isinstance(inputs, str):
            inputs = {"prompt": inputs}

        if "prompt_token_ids" in inputs:
            prompt_token_ids = inputs["prompt_token_ids"]
        else:
            tokenizer = self.get_tokenizer_group("prompts must be None if "
                                                 "skip_tokenizer_init is True")
            prompt_token_ids = await tokenizer.encode_async(
                request_id=request_id,
                prompt=inputs["prompt"],
                lora_request=lora_request)

        if prompt_adapter_request:
            num_virtual_tokens = (
                prompt_adapter_request.prompt_adapter_num_virtual_tokens)
            prompt_token_ids = [0] * num_virtual_tokens + prompt_token_ids

        return self.input_processor(
            LLMInputs(prompt_token_ids=prompt_token_ids,
                      prompt=inputs.get("prompt"),
                      multi_modal_data=inputs.get("multi_modal_data")))

    async def add_request_async(
            self,
            request_id: str,
            inputs: PromptInputs,
            params: Union[SamplingParams, PoolingParams],
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Dict[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> None:
        """Validate, preprocess and enqueue a request on this engine.

        Raises:
            ValueError: If ``lora_request`` is given while LoRA is disabled.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if arrival_time is None:
            arrival_time = time.time()

        processed_inputs = await self.process_model_inputs_async(
            request_id=request_id,
            inputs=inputs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
        )

    async def check_health_async(self) -> None:
        """Run the health checks of the tokenizer (when one is set) and of
        the model executor."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
333

334

335
class AsyncLLMEngine:
336
    """An asynchronous wrapper for :class:`LLMEngine`.
337

338
339
340
341
342
    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.
343
344
345
346
347

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
Zhuohan Li's avatar
Zhuohan Li committed
348
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
349
350
            async frontend will be executed in a separate process as the
            model workers.
351
        log_requests: Whether to log the requests.
zspo's avatar
zspo committed
352
353
        max_log_len: Maximum number of prompt characters or prompt ID numbers
            being printed in log.
354
355
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
356
357
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
358
    """
359

Antoni Baum's avatar
Antoni Baum committed
360
361
    # Class instantiated by _init_engine(), either directly or wrapped with
    # ray.remote(...) when engine_use_ray is set.
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

362
363
364
365
366
    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 max_log_len: Optional[int] = None,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Store the frontend options and build the underlying engine.

        ``*args``/``**kwargs`` are forwarded to :meth:`_init_engine`. The
        background loop is not started here; its handles stay ``None`` until
        :meth:`start_background_loop` runs.
        """
        # _init_engine() reads the two Ray flags, so they are stored first.
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
        self.max_log_len = max_log_len
        self.engine = self._init_engine(*args, **kwargs)

        # Shielded handle to the background task.
        self.background_loop: Optional[asyncio.Future] = None
        # Strong reference to the unshielded task so it cannot be
        # garbage collected while it runs.
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        # First fatal error of the background loop, if any.
        self._errored_with: Optional[BaseException] = None

        # Assigned by start_background_loop() so it binds to the running loop.
        self._request_tracker: RequestTracker

387
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments."""
        engine_config = engine_args.create_engine_config()

        if engine_args.engine_use_ray:
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()

        backend = engine_config.parallel_config.distributed_executor_backend
        executor_class = cls._get_executor_cls(engine_config)

        return cls(
            backend == "ray",
            engine_args.engine_use_ray,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            max_log_len=engine_args.max_log_len,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    @staticmethod
    def _get_executor_cls(engine_config) -> type:
        """Import and return the async executor class for ``engine_config``.

        The device type is checked first; for remaining devices the
        distributed executor backend decides. A ``"ray"`` backend also
        initializes the Ray cluster as a side effect.

        Raises:
            AssertionError: CPU/OpenVINO combined with a distributed backend.
            RuntimeError: XPU with a backend other than ``None``/``"ray"``.
        """
        backend = engine_config.parallel_config.distributed_executor_backend
        device_type = engine_config.device_config.device_type

        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            return NeuronExecutorAsync
        if device_type == "tpu":
            from vllm.executor.tpu_executor import TPUExecutorAsync
            return TPUExecutorAsync
        if device_type == "cpu":
            assert backend is None, (
                "Distributed execution is not supported with the CPU backend.")
            from vllm.executor.cpu_executor import CPUExecutorAsync
            return CPUExecutorAsync
        if device_type == "openvino":
            assert backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            return OpenVINOExecutorAsync
        if device_type == "xpu":
            if backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                return XPUExecutorAsync
            if backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                return RayXPUExecutorAsync
            raise RuntimeError(
                "Not supported distributed execution model on XPU device.")
        if backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            return RayGPUExecutorAsync
        if backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            return MultiprocessingGPUExecutorAsync
        from vllm.executor.gpu_executor import GPUExecutorAsync
        return GPUExecutorAsync

460
461
    @property
    def is_running(self) -> bool:
        """True while the background loop task exists and is not done."""
        task = self._background_loop_unshielded
        if self.background_loop is None or task is None:
            return False
        return not task.done()

    @property
    def is_stopped(self) -> bool:
        """True once the engine has errored or its background task is done."""
        if self.errored:
            return True
        task = self._background_loop_unshielded
        return (self.background_loop is not None and task is not None
                and task.done())

    @property
    def errored(self) -> bool:
        """Whether a fatal error has been recorded via set_errored()."""
        return not (self._errored_with is None)

    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` as the error that killed the engine."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine errored, then fail every tracked request stream
        with ``exc``."""
        self.set_errored(exc)
        tracker = self._request_tracker
        tracker.propagate_exception(exc)
482

483
484
    async def get_tokenizer(self) -> "PreTrainedTokenizer":
        """Return the engine's tokenizer, fetching it from the Ray actor when
        the engine runs remotely."""
        if not self.engine_use_ray:
            return self.engine.get_tokenizer()
        return await self.engine.get_tokenizer.remote()  # type: ignore
488

489
    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # Created here (not in __init__) so it binds to the active event loop.
        self._request_tracker = RequestTracker()

        event_loop = asyncio.get_event_loop()
        self._background_loop_unshielded = event_loop.create_task(
            self.run_engine_loop())
        on_done = partial(_log_task_completion,
                          error_callback=self._error_callback)
        self._background_loop_unshielded.add_done_callback(on_done)
        # Callers await the shielded future, so cancelling them does not
        # cancel the engine loop itself.
        self.background_loop = asyncio.shield(self._background_loop_unshielded)
Antoni Baum's avatar
Antoni Baum committed
504
505
506

    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
        """Build the engine in-process, or as a Ray actor when
        ``engine_use_ray`` is set.

        Actor resources: ``num_cpus=0`` when workers also use Ray; otherwise
        the engine's GPU-memory-utilization fraction for a single-GPU setup
        (TP == PP == 1), or one full GPU.
        """
        if not self.engine_use_ray:
            return self._engine_class(*args, **kwargs)

        if self.worker_use_ray:
            actor_options = {"num_cpus": 0}
        else:
            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
            # order of the arguments.
            cache_config = kwargs["cache_config"]
            parallel_config = kwargs["parallel_config"]
            single_device = (parallel_config.tensor_parallel_size == 1
                             and parallel_config.pipeline_parallel_size == 1)
            actor_options = {
                "num_gpus":
                cache_config.gpu_memory_utilization if single_device else 1
            }

        remote_engine_cls = ray.remote(**actor_options)(self._engine_class)
        return remote_engine_cls.remote(*args, **kwargs)

525
    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Pending requests are submitted (validation ``ValueError``s are routed
        to the request's stream), queued aborts are forwarded, one engine step
        runs, and every output is delivered to its stream.

        Returns True if there are in-progress requests."""
        new_requests, finished_requests = (
            self._request_tracker.get_new_and_finished_requests())

        if self.engine_use_ray:
            submit = self.engine.add_request.remote  # type: ignore
        else:
            submit = self.engine.add_request_async

        # TODO: Maybe add add_request_batch to reduce Ray overhead
        for request_kwargs in new_requests:
            try:
                await submit(**request_kwargs)
            except ValueError as err:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    request_kwargs["request_id"],
                    err,
                    verbose=self.log_requests,
                )

        if finished_requests:
            await self._engine_abort(finished_requests)

        if self.engine_use_ray:
            request_outputs = await self.engine.step.remote()  # type: ignore
        else:
            request_outputs = await self.engine.step_async(virtual_engine)

        for output in request_outputs:
            self._request_tracker.process_request_output(
                output, verbose=self.log_requests)

        return not all(output.finished for output in request_outputs)
566

Antoni Baum's avatar
Antoni Baum committed
567
568
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Abort ``request_ids`` in the engine (locally or via the Ray actor)."""
        if not self.engine_use_ray:
            self.engine.abort_request(request_ids)
            return
        await self.engine.abort_request.remote(request_ids)  # type: ignore

    async def run_engine_loop(self):
        """Background loop keeping one ``engine_step`` task per virtual engine
        alive while it has work.

        When every virtual engine is idle, the remote worker execution loop is
        stopped and the loop waits for new requests. Each wait on the step
        tasks is bounded by ``ENGINE_ITERATION_TIMEOUT_S``; a timeout marks
        the engine errored and is re-raised. The ``while True`` loop only
        exits through an exception.
        """
        if self.engine_use_ray:
            num_virtual_engines = 1  # type: ignore
        else:
            num_virtual_engines = (
                self.engine.parallel_config.pipeline_parallel_size)

        async def _stop_worker_loops() -> None:
            # Stop the execute model loop in parallel workers until there
            # are more requests to process. This avoids waiting
            # indefinitely in torch.distributed ops which may otherwise
            # timeout, and unblocks the RPC thread in the workers so that
            # they can process any other queued control plane messages,
            # such as add/remove lora adapters.
            if self.engine_use_ray:
                await (self.engine.stop_remote_worker_execution_loop.
                       remote()  # type: ignore
                       )
            else:
                await self.engine.stop_remote_worker_execution_loop_async()

        async def _has_unfinished(ve: int) -> bool:
            if self.engine_use_ray:
                return await (self.engine.
                              has_unfinished_requests_for_virtual_engine.
                              remote(ve))  # type: ignore
            return self.engine.has_unfinished_requests_for_virtual_engine(ve)

        busy = [False] * num_virtual_engines
        while True:
            if not any(busy):
                logger.debug("Waiting for new requests...")
                await _stop_worker_loops()
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")
                step_tasks = [
                    asyncio.create_task(self.engine_step(ve))
                    for ve in range(num_virtual_engines)
                ]
                busy = [True] * num_virtual_engines

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        step_tasks, return_when=asyncio.FIRST_COMPLETED)
                    for _ in range(num_virtual_engines):
                        await asyncio.sleep(0)
                for task in done:
                    step_result = task.result()
                    ve = step_tasks.index(task)
                    # Always queried (no short-circuit on step_result).
                    pending = await _has_unfinished(ve)
                    keep_going = bool(step_result or pending)
                    if keep_going:
                        step_tasks[ve] = asyncio.create_task(
                            self.engine_step(ve))
                    busy[ve] = keep_going
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise
            await asyncio.sleep(0)

    async def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncStream:
        """Log (optionally) and queue a request, returning its output stream.

        Starts the background loop if it is not running and
        ``start_engine_loop`` is set. ``arrival_time`` defaults to now.

        Raises:
            AsyncEngineDeadError: If the loop is not running and may not be
                started automatically.
        """
        if self.log_requests:
            if isinstance(inputs, str):
                prompt, token_ids = inputs, None
            else:
                prompt = inputs.get("prompt")
                token_ids = inputs.get("prompt_token_ids")

            limit = self.max_log_len
            if limit is not None:
                if prompt is not None:
                    prompt = prompt[:limit]
                if token_ids is not None:
                    token_ids = token_ids[:limit]

            logger.info(
                "Received request %s: prompt: %r, "
                "params: %s, prompt_token_ids: %s, "
                "lora_request: %s.", request_id, prompt, params, token_ids,
                lora_request)

        if not self.is_running:
            if not self.start_engine_loop:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")
            self.start_background_loop()

        return self._request_tracker.add_request(
            request_id,
            inputs=inputs,
            params=params,
            arrival_time=time.time() if arrival_time is None else arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request)
694

695
    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncIterator[RequestOutput]:
        """Generate outputs for a request.

        This coroutine places the request in the waiting queue of the
        LLMEngine and streams the engine's outputs back to the caller. Every
        output is checked with ``LLMEngine.validate_output`` to be a
        ``RequestOutput`` before it is yielded.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        outputs = self._process_request(
            request_id,
            inputs,
            sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
        )
        async for item in outputs:
            yield LLMEngine.validate_output(item, RequestOutput)

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> AsyncIterator[EmbeddingRequestOutput]:
        """Generate outputs for a request from an embedding model.

        This coroutine places the request in the waiting queue of the
        LLMEngine and streams the engine's outputs back to the caller. Every
        output is checked with ``LLMEngine.validate_output`` to be an
        ``EmbeddingRequestOutput`` before it is yielded.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.

        Yields:
            The output `EmbeddingRequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        outputs = self._process_request(
            request_id,
            inputs,
            pooling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )
        async for item in outputs:
            yield LLMEngine.validate_output(item, EmbeddingRequestOutput)

    async def _process_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        *,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Common logic to process requests with SamplingParams or
        PoolingParams.

        Submits the request through :meth:`add_request` and re-yields every
        output from the resulting stream. If iteration ends abnormally (an
        exception, task cancellation, or the consumer closing this generator
        before the stream is exhausted), the request is aborted via
        :meth:`_abort` and the original exception is re-raised.
        """
        arrival_time = time.time()

        stream = await self.add_request(
            request_id,
            inputs,
            params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
        )

        try:
            async for request_output in stream:
                yield request_output
        except (Exception, asyncio.CancelledError, GeneratorExit):
            # GeneratorExit (raised at the `yield` when the consumer calls
            # aclose() or abandons iteration) derives from BaseException, so
            # it must be listed explicitly; otherwise an early-closed stream
            # would skip the abort. Aborting an already-finished request is
            # a no-op, so this is safe after the last output as well.
            self._abort(request_id)
            # Bare `raise` keeps the original traceback intact.
            raise

    async def abort(self, request_id: str) -> None:
        """Abort a submitted request.

        The request is handed to :meth:`_abort`; a request that is already
        finished or unknown is a no-op there.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: if the background loop is not running.
        """
        if self.is_running:
            return self._abort(request_id)

        raise AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")

    def _abort(self, request_id: str) -> None:
        """Abort a request synchronously through the request tracker.

        Unlike :meth:`abort`, this does not check whether the background loop
        is running. If the request is finished or not found, this method will
        be a no-op.

        Args:
            request_id: The unique id of the request.
        """
        tracker = self._request_tracker
        # Tracker-side abort logging follows the engine's request-logging flag.
        tracker.abort_request(request_id, verbose=self.log_requests)

    async def get_model_config(self) -> ModelConfig:
        """Return the model configuration of the underlying vLLM engine.

        The configuration is fetched from the Ray actor when the engine
        runs under Ray, and read directly otherwise.
        """
        if not self.engine_use_ray:
            return self.engine.get_model_config()
        return await self.engine.get_model_config.remote()  # type: ignore

    async def get_decoding_config(self) -> DecodingConfig:
        """Return the decoding configuration of the underlying vLLM engine.

        The configuration is fetched from the Ray actor when the engine
        runs under Ray, and read directly otherwise.
        """
        if not self.engine_use_ray:
            return self.engine.get_decoding_config()
        remote_call = self.engine.get_decoding_config.remote  # type: ignore
        return await remote_call()

    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the underlying engine to log its current stats.

        Args:
            scheduler_outputs: Optional scheduler outputs forwarded to the
                engine's ``do_log_stats``.
            model_output: Optional sampler outputs forwarded to the engine's
                ``do_log_stats``.
        """
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote(  # type: ignore
                scheduler_outputs, model_output)
        else:
            # Forward the arguments in the local case as well; previously
            # they were silently dropped here while the Ray path passed them
            # on to the same engine method.
            self.engine.do_log_stats(scheduler_outputs, model_output)

    async def check_health(self) -> None:
        """Raise an error if the engine is unhealthy.

        The elapsed time is logged at debug level when the check completes
        without raising.

        Raises:
            AsyncEngineDeadError: if the background loop is stopped.
            RuntimeError: if the Ray engine actor reports a
                ``RayActorError``.
        """
        started = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        if not self.engine_use_ray:
            await self.engine.check_health_async()
        else:
            try:
                await self.engine.check_health.remote()  # type: ignore
            except ray.exceptions.RayActorError as err:
                raise RuntimeError("Engine is dead.") from err

        logger.debug("Health check took %fs", time.perf_counter() - started)

    async def is_tracing_enabled(self) -> bool:
        """Report whether tracing is enabled on the underlying engine.

        The answer comes from the Ray actor when the engine runs under Ray.
        """
        if not self.engine_use_ray:
            return self.engine.is_tracing_enabled()
        return await self.engine.is_tracing_enabled.remote(  # type: ignore
        )

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register ``logger`` on the underlying engine as ``logger_name``.

        When the engine is a Ray actor the registration is done remotely and
        this method waits for it with ``ray.get``.
        """
        if not self.engine_use_ray:
            self.engine.add_logger(logger_name=logger_name, logger=logger)
            return
        pending = self.engine.add_logger.remote(  # type: ignore
            logger_name=logger_name, logger=logger)
        ray.get(pending)

    def remove_logger(self, logger_name: str) -> None:
        """Detach the stat logger registered on the engine as ``logger_name``.

        When the engine is a Ray actor the removal is done remotely and this
        method waits for it with ``ray.get``.
        """
        if not self.engine_use_ray:
            self.engine.remove_logger(logger_name=logger_name)
            return
        pending = self.engine.remove_logger.remote(  # type: ignore
            logger_name=logger_name)
        ray.get(pending)