async_llm_engine.py 40.5 KB
Newer Older
1
2
import asyncio
import time
Antoni Baum's avatar
Antoni Baum committed
3
from functools import partial
4
from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
5
                    Optional, Set, Tuple, Type, Union)
6

7
8
from transformers import PreTrainedTokenizer

9
import vllm.envs as envs
10
11
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
12
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
13
from vllm.engine.arg_utils import AsyncEngineArgs
14
from vllm.engine.async_timeout import asyncio_timeout
Woosuk Kwon's avatar
Woosuk Kwon committed
15
from vllm.engine.llm_engine import LLMEngine
16
from vllm.engine.metrics import StatLoggerBase
17
from vllm.executor.executor_base import ExecutorAsyncBase
18
from vllm.executor.ray_utils import initialize_ray_cluster, ray
19
from vllm.inputs import LLMInputs, PromptInputs
Woosuk Kwon's avatar
Woosuk Kwon committed
20
from vllm.logger import init_logger
21
from vllm.lora.request import LoRARequest
22
23
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
24
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
25
from vllm.sampling_params import SamplingParams
26
from vllm.sequence import ExecuteModelRequest, SamplerOutput
yhu422's avatar
yhu422 committed
27
from vllm.usage.usage_lib import UsageContext
28
29

logger = init_logger(__name__)

# Upper bound (seconds) on one wait for engine step tasks in the engine loop;
# exceeding it marks the engine as errored. Taken from vllm.envs.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S


class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine's background loop has failed or has errored."""


def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback for the `engine.run_engine_loop()` task.

    That task runs a `while True` loop, so it can only end by being cancelled
    or by raising. Cancellation is logged as a graceful shutdown. Any other
    outcome (including a normal return, which is converted into an
    AssertionError) is logged, passed to ``error_callback`` and re-raised as
    :class:`AsyncEngineDeadError`.
    """
    if task.cancelled():
        # We assume that if the task is cancelled, we are gracefully shutting
        # down. This should only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
        return

    try:
        return_value = task.result()
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {return_value}")
    except Exception as exc:
        logger.error("Engine background task failed", exc_info=exc)
        error_callback(exc)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from exc


# Sentinel that AsyncStream.finish() enqueues to mark the normal end of a
# stream; AsyncStream.generator() stops when it dequeues this exact object.
STOP_ITERATION = Exception()


class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously via an async generator.

    Items are buffered in an unbounded :class:`asyncio.Queue`. :meth:`finish`
    enqueues a terminator behind whatever is still buffered: the module-level
    ``STOP_ITERATION`` sentinel for normal completion, or
    :class:`asyncio.CancelledError` when ``cancelled=True``. The consumer
    therefore always drains every buffered output, and any exception put on
    the stream, before the stream ends.
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        # Invoked with request_id when the consumer abandons the generator.
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Enqueue an output or an exception; ignored once finished."""
        if self._finished:
            return
        self._queue.put_nowait(item)

    def finish(self, cancelled: bool = False) -> None:
        """Mark the stream finished (idempotent) and enqueue its terminator."""
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                asyncio.CancelledError if cancelled else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        return self._finished

    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Yield buffered outputs until the terminator is dequeued.

        Any exception put on the stream is raised when it is reached. This
        includes ``asyncio.CancelledError`` when the stream was finished with
        ``cancelled=True``, and covers both exception instances and classes.
        If the consumer closes the generator early, the cancel callback is
        invoked and ``asyncio.CancelledError`` is raised.
        """
        try:
            # Loop until the terminator is dequeued rather than checking
            # self._finished: finish() may be called while earlier outputs
            # (or a propagated exception) are still buffered, and those must
            # reach the consumer instead of being silently dropped.
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: object) -> bool:
        """Return True for exception instances and exception classes.

        finish(cancelled=True) enqueues the CancelledError *class*, and
        CancelledError derives from BaseException (not Exception), so an
        ``isinstance(value, Exception)`` check would let it through as a
        regular output.
        """
        return isinstance(value, BaseException) or (
            isinstance(value, type) and issubclass(value, BaseException))


class RequestTracker:
    """Synchronous abstraction for tracking requests."""

    def __init__(self) -> None:
        # Streams of requests already handed to the engine, keyed by id.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Ids of requests to abort on the next background loop iteration.
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, engine add_request kwargs) pairs not yet sent to the engine.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set by add_request(); waited on and cleared by
        # wait_for_new_requests().
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def _put_if_tracked(self, request_id: str, item: Exception) -> None:
        """Put ``item`` on the stream of ``request_id`` if still tracked.

        The stream may already be gone if the request was aborted (e.g. its
        consumer went away) while the engine was still working on it. Indexing
        ``_request_streams`` directly would then raise KeyError inside the
        engine loop and take the whole engine down.
        """
        stream = self._request_streams.get(request_id)
        if stream is not None:
            stream.put(item)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None).

        Every affected request is then aborted, which also finishes its
        stream.
        """
        if request_id is not None:
            self._put_if_tracked(request_id, exc)
            self.abort_request(request_id)
        else:
            # NB: list() used here because self.abort_request pops the stream
            # out of self._request_streams, so we can't iterate on it directly
            for rid, stream in list(self._request_streams.items()):
                stream.put(exc)
                self.abort_request(rid)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        Routes the output to its request's stream. A finished output also
        removes the stream from tracking and finishes it.
        """
        request_id = request_output.request_id
        finished = request_output.finished

        if finished:
            stream = self._request_streams.pop(request_id, None)
        else:
            stream = self._request_streams.get(request_id)
        # Guard against a KeyError which can occur if the request was aborted
        # while the output was generated
        if stream is not None:
            stream.put(request_output)
            if finished:
                stream.finish()

        if verbose and finished:
            logger.info("Finished request %s.", request_id)

    def process_exception(self,
                          request_id: str,
                          exception: Exception,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine.

        The exception is delivered to the request's stream if it is still
        tracked, and the request is then aborted.
        """
        # Same guard as process_request_output: the request may have been
        # aborted while the engine was processing it.
        self._put_if_tracked(request_id, exception)
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Raises:
            KeyError: if a stream for ``request_id`` is already tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        abort_request = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, abort_request)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      cancelled: bool = False,
                      verbose: bool = False) -> None:
        """Abort a request during next background loop iteration.

        If the request's stream is tracked, it is removed and finished
        immediately with the given ``cancelled`` flag. Streams still waiting
        in the new-request queue are finished by
        get_new_and_aborted_requests().
        """
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(cancelled=cancelled)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        New requests whose id was aborted before reaching the engine are
        dropped and their streams finished with ``cancelled=True``.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._aborted_requests.empty():
            request_id = self._aborted_requests.get_nowait()
            finished_requests.add(request_id)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            if stream.request_id in finished_requests:
                # The request has already been aborted.
                stream.finish(cancelled=True)
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Wait until at least one new request is queued, then reset the
        new-requests event."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        return not self._new_requests.empty()


class _AsyncLLMEngine(LLMEngine):
    """LLMEngine subclass that adds coroutine versions of its core methods."""

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Run one decoding iteration on ``virtual_engine``; return results.

        The scheduler of that virtual engine picks the next batch. If anything
        was scheduled, the batch is executed through
        ``model_executor.execute_model_async``; otherwise the model output is
        an empty list. The output is converted into request outputs, and then
        stats logging and tracing are performed for the iteration.
        """
        scheduler = self.scheduler[virtual_engine]
        seq_group_metadata_list, scheduler_outputs = scheduler.schedule()

        if scheduler_outputs.is_empty():
            output = []
        else:
            # Execute the model.
            finished_requests_ids = (
                scheduler.get_and_reset_finished_requests_ids())
            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids)
            output = await self.model_executor.execute_model_async(
                execute_model_req)

        request_outputs = self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

        # Log stats, then record tracing for this iteration.
        self.do_log_stats(scheduler_outputs, output)
        self.do_tracing(scheduler_outputs)

        return request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def process_model_inputs_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        """Build the engine-ready inputs for one request.

        A bare string is treated as ``{"prompt": inputs}``. The prompt is
        tokenized asynchronously unless ``prompt_token_ids`` are supplied.
        With a prompt adapter, one ``0`` token id is prepended per adapter
        virtual token. The result is passed through ``self.input_processor``.
        """
        if isinstance(inputs, str):
            inputs = {"prompt": inputs}

        if "prompt_token_ids" in inputs:
            prompt_token_ids = inputs["prompt_token_ids"]
        else:
            tokenizer = self.get_tokenizer_group("prompts must be None if "
                                                 "skip_tokenizer_init is True")
            prompt_token_ids = await tokenizer.encode_async(
                request_id=request_id,
                prompt=inputs["prompt"],
                lora_request=lora_request)

        if prompt_adapter_request:
            num_virtual_tokens = (
                prompt_adapter_request.prompt_adapter_num_virtual_tokens)
            prompt_token_ids = [0] * num_virtual_tokens + prompt_token_ids

        llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
                               prompt=inputs.get("prompt"),
                               multi_modal_data=inputs.get("multi_modal_data"))
        return self.input_processor(llm_inputs)

    async def add_request_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        """Validate and preprocess a request, then hand it to the engine.

        ``arrival_time`` defaults to the current ``time.time()``.

        Raises:
            ValueError: if ``lora_request`` is given but LoRA is not enabled.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if arrival_time is None:
            arrival_time = time.time()

        processed_inputs = await self.process_model_inputs_async(
            request_id=request_id,
            inputs=inputs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
        )

    async def check_health_async(self) -> None:
        """Check the tokenizer's health (when present), then the executor's."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()


class AsyncLLMEngine:
368
    """An asynchronous wrapper for :class:`LLMEngine`.
369

370
371
372
373
374
    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.
375
376
377
378
379

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
Zhuohan Li's avatar
Zhuohan Li committed
380
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
381
382
            async frontend will be executed in a separate process as the
            model workers.
383
        log_requests: Whether to log the requests.
384
385
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
386
387
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
388
    """
389

Antoni Baum's avatar
Antoni Baum committed
390
391
    # Engine class instantiated by _init_engine(), either directly or wrapped
    # as a Ray actor.
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Record frontend options and build the wrapped engine.

        ``*args`` and the remaining ``**kwargs`` are forwarded unchanged to
        :meth:`_init_engine`. That method reads ``engine_use_ray`` and
        ``worker_use_ray``, so both are stored before it is called.
        """
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
        self.engine = self._init_engine(*args, **kwargs)

        # Shielded handle to the background loop task, plus a strong
        # reference to the unshielded task itself so it is not garbage
        # collected. Both stay None until the loop is started.
        self.background_loop: Optional[asyncio.Future] = None
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        # Recorded by set_errored(); while set, `errored` is True.
        self._errored_with: Optional[BaseException] = None

        # Lazy initialized fields (assigned in start_background_loop()).
        self._request_tracker: RequestTracker

    @classmethod
    def _get_executor_cls(
            cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
        """Pick the async executor class for the configured device/backend.

        Resolution order:

        1. An executor class passed directly as ``distributed_executor_backend``.
        2. The device type: neuron, tpu, cpu, openvino, xpu.
        3. The ``"ray"`` / ``"mp"`` GPU backends.
        4. The single-process GPU executor.

        Ray-based choices initialize the Ray cluster first.
        """
        parallel_config = engine_config.parallel_config
        backend = parallel_config.distributed_executor_backend

        if isinstance(backend, type):
            if not issubclass(backend, ExecutorAsyncBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorAsyncBase. Got {backend}.")
            if backend.uses_ray:  # type: ignore
                initialize_ray_cluster(parallel_config)
            return backend

        device_type = engine_config.device_config.device_type

        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            return NeuronExecutorAsync

        if device_type == "tpu":
            if backend == "ray":
                initialize_ray_cluster(parallel_config)
                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                return RayTPUExecutorAsync
            assert backend is None
            from vllm.executor.tpu_executor import TPUExecutorAsync
            return TPUExecutorAsync

        if device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutorAsync
            return CPUExecutorAsync

        if device_type == "openvino":
            assert backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            return OpenVINOExecutorAsync

        if device_type == "xpu":
            if backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                return XPUExecutorAsync
            if backend == "ray":
                initialize_ray_cluster(parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                return RayXPUExecutorAsync
            raise RuntimeError(
                "Not supported distributed execution model on XPU device.")

        if backend == "ray":
            initialize_ray_cluster(parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            return RayGPUExecutorAsync

        if backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            return MultiprocessingGPUExecutorAsync

        from vllm.executor.gpu_executor import GPUExecutorAsync
        return GPUExecutorAsync

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments."""
        engine_config = engine_args.create_engine_config()

        if engine_args.engine_use_ray:
            # Check Ray availability before any engine construction.
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()

        executor_class = cls._get_executor_cls(engine_config)

        # The executor decides whether workers use Ray; the engine-level
        # Ray flag comes straight from the engine arguments.
        return cls(
            executor_class.uses_ray,
            engine_args.engine_use_ray,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    @property
    def is_running(self) -> bool:
        """Whether the background loop was started and its task is not done."""
        unshielded = self._background_loop_unshielded
        if self.background_loop is None or unshielded is None:
            return False
        return not unshielded.done()

    @property
    def is_stopped(self) -> bool:
        """Whether the engine errored or its started background task ended."""
        if self.errored:
            return True
        unshielded = self._background_loop_unshielded
        if self.background_loop is None or unshielded is None:
            return False
        return unshielded.done()

    @property
    def errored(self) -> bool:
        """True once an error has been recorded with set_errored()."""
        recorded_error = self._errored_with
        return recorded_error is not None

    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` as the error that stopped the engine."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine errored, then push ``exc`` to every tracked request
        stream (each of which is aborted)."""
        self.set_errored(exc)
        tracker = self._request_tracker
        tracker.propagate_exception(exc)

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> "PreTrainedTokenizer":
        """Return the tokenizer for ``lora_request``.

        When the engine is a Ray actor the call goes through the actor;
        otherwise the engine's tokenizer group is queried directly.
        """
        if not self.engine_use_ray:
            tokenizer_group = self.engine.get_tokenizer_group()
            return await tokenizer_group.get_lora_tokenizer_async(lora_request)
        return await self.engine.get_tokenizer.remote(  # type: ignore
            lora_request)

    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: if the engine has already errored.
            RuntimeError: if the loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        event_loop = asyncio.get_event_loop()
        self._background_loop_unshielded = event_loop.create_task(
            self.run_engine_loop())
        self._background_loop_unshielded.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        # Expose a shielded future: cancelling it does not cancel the
        # underlying engine loop task.
        self.background_loop = asyncio.shield(self._background_loop_unshielded)

    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
        """Instantiate the engine in-process, or as a Ray actor when
        ``engine_use_ray`` is set.

        As an actor it reserves no CPUs when workers use Ray. Otherwise it
        reserves GPUs: ``gpu_memory_utilization`` of one GPU when tensor and
        pipeline parallel sizes are both 1, else one full GPU.
        """
        if not self.engine_use_ray:
            return self._engine_class(*args, **kwargs)

        if self.worker_use_ray:
            actor_cls = ray.remote(num_cpus=0)(self._engine_class)
        else:
            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
            # order of the arguments.
            cache_config = kwargs["cache_config"]
            parallel_config = kwargs["parallel_config"]
            tp_and_pp_are_one = (parallel_config.tensor_parallel_size == 1
                                 and parallel_config.pipeline_parallel_size == 1)
            num_gpus = (cache_config.gpu_memory_utilization
                        if tp_and_pp_are_one else 1)
            actor_cls = ray.remote(num_gpus=num_gpus)(self._engine_class)
        return actor_cls.remote(*args, **kwargs)

    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Steps:

        1. Forward newly queued requests to the engine.
        2. Abort requests that were cancelled.
        3. Run one engine step and route its outputs to the request streams.

        A new request rejected with ValueError has the error delivered to its
        stream.

        Returns True if any output of this step is not yet finished.
        """
        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())

        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            # TODO: Maybe add add_request_batch to reduce Ray overhead
            try:
                if self.engine_use_ray:
                    await self.engine.add_request.remote(  # type: ignore
                        **new_request)
                else:
                    await self.engine.add_request_async(**new_request)
            except ValueError as e:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    new_request["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if aborted_requests:
            await self._engine_abort(aborted_requests)

        if self.engine_use_ray:
            request_outputs = await self.engine.step.remote()  # type: ignore
        else:
            request_outputs = await self.engine.step_async(virtual_engine)

        # Put the outputs into the corresponding streams.
        any_unfinished = False
        for request_output in request_outputs:
            self._request_tracker.process_request_output(
                request_output, verbose=self.log_requests)
            any_unfinished = any_unfinished or not request_output.finished

        return any_unfinished

    async def _engine_abort(self, request_ids: Iterable[str]):
        """Forward aborted request ids to the local or Ray-actor engine."""
        if not self.engine_use_ray:
            self.engine.abort_request(request_ids)
            return
        await self.engine.abort_request.remote(request_ids)  # type: ignore

    async def run_engine_loop(self):
        """Drive engine_step() for every virtual engine, forever.

        One engine_step task runs per pipeline-parallel virtual engine (a
        single one when the engine is a Ray actor). When none of them has work
        left, the remote worker execution loops are stopped and the loop waits
        for new requests.

        Each wait on the step tasks is bounded by ENGINE_ITERATION_TIMEOUT_S.
        On timeout the engine is marked errored and the timeout is re-raised.
        """
        if self.engine_use_ray:
            num_virtual_engines = 1  # type: ignore
        else:
            num_virtual_engines = (
                self.engine.parallel_config.pipeline_parallel_size)
        busy = [False] * num_virtual_engines

        while True:
            if not any(busy):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                if self.engine_use_ray:
                    await (self.engine.stop_remote_worker_execution_loop.
                           remote()  # type: ignore
                           )
                else:
                    await self.engine.stop_remote_worker_execution_loop_async()
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")
                step_tasks = [
                    asyncio.create_task(self.engine_step(ve))
                    for ve in range(num_virtual_engines)
                ]
                busy = [True] * num_virtual_engines

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        step_tasks, return_when=asyncio.FIRST_COMPLETED)
                    for _ in range(num_virtual_engines):
                        await asyncio.sleep(0)
                for task in done:
                    step_result = task.result()
                    virtual_engine = step_tasks.index(task)
                    if self.engine_use_ray:
                        has_unfinished_requests = (
                            await (self.engine.
                                   has_unfinished_requests_for_virtual_engine.
                                   remote(  # type: ignore
                                       virtual_engine)))
                    else:
                        has_unfinished_requests = (
                            self.engine.
                            has_unfinished_requests_for_virtual_engine(
                                virtual_engine))
                    if step_result or has_unfinished_requests:
                        # Keep this virtual engine busy with another step.
                        step_tasks[virtual_engine] = asyncio.create_task(
                            self.engine_step(virtual_engine))
                        busy[virtual_engine] = True
                    else:
                        busy[virtual_engine] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise
            await asyncio.sleep(0)

    # This method does not need to be async, but kept that way
    # for backwards compatibility.
Antoni Baum's avatar
Antoni Baum committed
692
693
694
    async def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        """Register a request with the request tracker and return its stream.

        If the background loop is not running, it is started when
        ``self.start_engine_loop`` is true; otherwise
        :class:`AsyncEngineDeadError` is raised.

        Args:
            request_id: The unique id of the request.
            inputs: The inputs to the LLM.
            params: Sampling parameters (generation) or pooling parameters
                (embedding) for the request.
            arrival_time: Arrival timestamp of the request. When ``None``,
                the current ``time.time()`` is used.
            lora_request: LoRA request to use, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt adapter request to use, if any.

        Returns:
            The async generator of the stream created for this request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                ``self.start_engine_loop`` is false.
        """
        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        # Compare against None explicitly: a caller-supplied timestamp of
        # 0.0 is falsy and would otherwise be silently replaced by "now".
        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            inputs=inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request)

        return stream.generator()
723

724
    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        This is an async generator. It submits the request through
        :meth:`add_request` and yields every output streamed back for it,
        each checked to be a `RequestOutput`.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests. If ``start_engine_loop`` is
              false, :class:`AsyncEngineDeadError` is raised instead.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> request_id = "0"
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    request_id)
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(request_id)
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        results = await self.add_request(
            request_id,
            inputs,
            sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
        )
        async for raw_output in results:
            yield LLMEngine.validate_output(raw_output, RequestOutput)
806
807
808

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model.

        This is an async generator. It submits the request through
        :meth:`add_request` and yields every output streamed back for it,
        each checked to be an `EmbeddingRequestOutput`.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.

        Yields:
            The output `EmbeddingRequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests. If ``start_engine_loop`` is
              false, :class:`AsyncEngineDeadError` is raised instead.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> request_id = "0"
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    request_id)
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(request_id)
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        results = await self.add_request(
            request_id,
            inputs,
            pooling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )
        async for raw_output in results:
            yield LLMEngine.validate_output(raw_output,
                                            EmbeddingRequestOutput)
883

Antoni Baum's avatar
Antoni Baum committed
884
885
    async def abort(self, request_id: str) -> None:
        """Abort a submitted request.

        If the request is finished or not found, this method is a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if self.is_running:
            return self._abort(request_id)
        raise AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")
901

Antoni Baum's avatar
Antoni Baum committed
902
    def _abort(self, request_id: str) -> None:
        """Abort a request without checking the background loop's state.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
        tracker = self._request_tracker
        # cancelled=True marks this as a client-initiated abort.
        tracker.abort_request(request_id,
                              cancelled=True,
                              verbose=self.log_requests)
914

915
916
917
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine.

        When ``engine_use_ray`` is set, the config is fetched through a Ray
        remote call on the engine actor.
        """
        if not self.engine_use_ray:
            return self.engine.get_model_config()
        return await self.engine.get_model_config.remote()  # type: ignore

922
923
924
925
926
927
928
929
    async def get_parallel_config(self) -> ParallelConfig:
        """Get the parallel configuration of the vLLM engine.

        When ``engine_use_ray`` is set, the config is fetched through a Ray
        remote call on the engine actor.
        """
        if not self.engine_use_ray:
            return self.engine.get_parallel_config()
        return await self.engine.get_parallel_config.remote()  # type: ignore

930
931
932
933
934
935
936
937
    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine.

        When ``engine_use_ray`` is set, the config is fetched through a Ray
        remote call on the engine actor.
        """
        if not self.engine_use_ray:
            return self.engine.get_decoding_config()
        return await self.engine.get_decoding_config.remote()  # type: ignore

938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
    async def get_scheduler_config(self) -> SchedulerConfig:
        """Get the scheduling configuration of the vLLM engine.

        When ``engine_use_ray`` is set, the config is fetched through a Ray
        remote call on the engine actor.
        """
        if not self.engine_use_ray:
            return self.engine.get_scheduler_config()
        return await self.engine.get_scheduler_config.remote()  # type: ignore

    async def get_lora_config(self) -> LoRAConfig:
        """Get the LoRA configuration of the vLLM engine.

        When ``engine_use_ray`` is set, the config is fetched through a Ray
        remote call on the engine actor.
        """
        if not self.engine_use_ray:
            return self.engine.get_lora_config()
        return await self.engine.get_lora_config.remote()  # type: ignore

954
955
956
957
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the engine to log its stats.

        Args:
            scheduler_outputs: Optional scheduler outputs, forwarded
                unchanged to the engine's ``do_log_stats``.
            model_output: Optional sampler outputs, forwarded unchanged to
                the engine's ``do_log_stats``.
        """
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote(  # type: ignore
                scheduler_outputs, model_output)
        else:
            # Forward the same arguments as the Ray path; previously the
            # local path silently dropped them.
            self.engine.do_log_stats(scheduler_outputs, model_output)
963

964
    async def check_health(self) -> None:
        """Raise an error if the engine is unhealthy.

        Raises:
            AsyncEngineDeadError: If the background loop has stopped.
            RuntimeError: If the Ray engine actor is dead
                (``ray.exceptions.RayActorError``).
        """
        started = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        if not self.engine_use_ray:
            await self.engine.check_health_async()
        else:
            try:
                await self.engine.check_health.remote()  # type: ignore
            except ray.exceptions.RayActorError as e:
                raise RuntimeError("Engine is dead.") from e
        elapsed = time.perf_counter() - started
        logger.debug("Health check took %fs", elapsed)
979
980
981
982
983
984
985

    async def is_tracing_enabled(self) -> bool:
        """Report whether the engine has tracing enabled.

        When ``engine_use_ray`` is set, the flag is fetched through a Ray
        remote call on the engine actor.
        """
        if not self.engine_use_ray:
            return self.engine.is_tracing_enabled()
        return await self.engine.is_tracing_enabled.remote()  # type: ignore
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger on the engine under ``logger_name``.

        When ``engine_use_ray`` is set, this blocks in ``ray.get`` until the
        remote call on the engine actor completes.
        """
        if not self.engine_use_ray:
            self.engine.add_logger(logger_name=logger_name, logger=logger)
            return
        remote_ref = self.engine.add_logger.remote(  # type: ignore
            logger_name=logger_name, logger=logger)
        ray.get(remote_ref)

    def remove_logger(self, logger_name: str) -> None:
        """Remove the stat logger registered under ``logger_name``.

        When ``engine_use_ray`` is set, this blocks in ``ray.get`` until the
        remote call on the engine actor completes.
        """
        if not self.engine_use_ray:
            self.engine.remove_logger(logger_name=logger_name)
            return
        remote_ref = self.engine.remove_logger.remote(  # type: ignore
            logger_name=logger_name)
        ray.get(remote_ref)