async_llm_engine.py 40.2 KB
Newer Older
1
2
import asyncio
import time
Antoni Baum's avatar
Antoni Baum committed
3
from functools import partial
4
5
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
                    Set, Tuple, Type, Union)
6

7
8
from transformers import PreTrainedTokenizer

9
import vllm.envs as envs
10
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
11
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
12
from vllm.engine.arg_utils import AsyncEngineArgs
13
from vllm.engine.async_timeout import asyncio_timeout
Woosuk Kwon's avatar
Woosuk Kwon committed
14
from vllm.engine.llm_engine import LLMEngine
15
from vllm.engine.metrics import StatLoggerBase
16
from vllm.executor.executor_base import ExecutorAsyncBase
17
from vllm.executor.ray_utils import initialize_ray_cluster, ray
18
from vllm.inputs import LLMInputs, PromptInputs
Woosuk Kwon's avatar
Woosuk Kwon committed
19
from vllm.logger import init_logger
20
from vllm.lora.request import LoRARequest
21
22
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
23
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
24
from vllm.sampling_params import SamplingParams
25
from vllm.sequence import ExecuteModelRequest, SamplerOutput
yhu422's avatar
yhu422 committed
26
from vllm.usage.usage_lib import UsageContext
27
28

logger = init_logger(__name__)

# Upper bound on the duration of a single engine iteration (seconds, per the
# ``_S`` suffix). ``run_engine_loop`` marks the engine as errored and re-raises
# when an iteration exceeds it.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
30

Antoni Baum's avatar
Antoni Baum committed
31

32
33
34
35
class AsyncEngineDeadError(RuntimeError):
    """Raised when the background engine loop has failed and the engine can
    no longer serve requests."""


36
37
38
39
40
41
42
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback for the ``engine.run_engine_loop()`` background task.

    That task runs a ``while True`` loop that can only exit if there is an
    exception, so every outcome other than cancellation is treated as fatal:

    * Cancellation is assumed to be a graceful shutdown and is only logged.
    * Any other exception, including the task returning normally (converted
      into an ``AssertionError`` below), is logged, handed to
      ``error_callback`` and re-raised as :class:`AsyncEngineDeadError`.

    Args:
        task: The finished background task.
        error_callback: Invoked with the exception that ended the loop.

    Raises:
        AsyncEngineDeadError: If the task did not end through cancellation.
    """
    try:
        return_value = task.result()
        # Raised inside the ``try`` on purpose: a normal return must go
        # through the same error-reporting path as a real failure.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {return_value}")
    except asyncio.CancelledError:
        # We assume that if the task is cancelled, we are gracefully shutting
        # down. This should only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from e
62
63


Antoni Baum's avatar
Antoni Baum committed
64
class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously.

    Items are yielded in the order they were ``put``. An ``Exception`` item is
    raised to the consumer instead of being yielded. After :meth:`finish`,
    further ``put`` calls are ignored and iteration ends once the items queued
    before the end-of-stream sentinel have been consumed.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Enqueue an output (or an exception to raise) for the consumer.

        Silently ignored once the stream has been finished."""
        if self._finished:
            return
        self._queue.put_nowait(item)

    def finish(self) -> None:
        """Mark the stream finished and enqueue the end-of-stream sentinel.

        Idempotent: only the first call enqueues ``StopAsyncIteration``, so
        repeated calls cannot leave stale sentinels in the queue."""
        if self._finished:
            return
        self._finished = True
        self._queue.put_nowait(StopAsyncIteration())

    @property
    def finished(self) -> bool:
        return self._finished

    def __aiter__(self):
        return self

    async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
        result = await self._queue.get()
        # Both queued errors and the StopAsyncIteration sentinel are
        # Exception instances; raising the sentinel ends ``async for``.
        if isinstance(result, Exception):
            raise result
        return result


97
98
99
100
101
102
103
104
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    Holds one :class:`AsyncStream` per registered request id plus two queues:
    requests waiting to be handed to the engine, and ids of requests that have
    finished or been aborted. :meth:`get_new_and_finished_requests` drains
    both queues for the background loop.
    """

    def __init__(self) -> None:
        self._request_streams: Dict[str, AsyncStream] = {}
        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Set by add_request(); awaited by wait_for_new_requests().
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None).

        Every affected request is also aborted. If ``request_id`` no longer
        has a stream (it was already finished and dropped), it is only
        aborted instead of raising ``KeyError``."""
        if request_id is not None:
            if (stream := self._request_streams.get(request_id)) is not None:
                stream.put(exc)
            self.abort_request(request_id)
        else:
            for rid, stream in self._request_streams.items():
                stream.put(exc)
                self.abort_request(rid)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine."""
        request_id = request_output.request_id

        # Guard against a KeyError which can occur if the request was aborted
        # while the output was generated
        if (stream := self._request_streams.get(request_id)) is not None:
            stream.put(request_output)
        if request_output.finished:
            if verbose:
                logger.info("Finished request %s.", request_id)
            self.abort_request(request_id)

    def process_exception(self,
                          request_id: str,
                          exception: Exception,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine.

        Like :meth:`process_request_output`, tolerates a request whose stream
        was already dropped. With several concurrent engine steps, one step
        can pop an aborted request while another is still awaiting its
        ``add_request``."""
        if (stream := self._request_streams.get(request_id)) is not None:
            stream.put(exception)
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id)

    def add_request(self, request_id: str,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Raises:
            KeyError: If ``request_id`` is already being tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        self.new_requests_event.set()

        return stream

    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
        """Abort a request during next background loop iteration."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._finished_requests.put_nowait(request_id)

        stream = self._request_streams.get(request_id)
        if stream is None or stream.finished:
            # The request has already finished or been aborted.
            return

        stream.finish()

    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        Finished ids are removed from the tracked streams. New requests whose
        id is already in the finished set are finished immediately and are
        not returned."""
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._finished_requests.empty():
            request_id = self._finished_requests.get_nowait()
            finished_requests.add(request_id)
            self._request_streams.pop(request_id, None)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            if stream.request_id in finished_requests:
                # The request has already been aborted.
                stream.finish()
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Wait for add_request() to signal new work, unless some is already
        queued, then clear the signal."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        return not self._new_requests.empty()
214

Antoni Baum's avatar
Antoni Baum committed
215
216
217
218

class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.

        The scheduler of ``virtual_engine`` first chooses the sequences to
        execute next and the token blocks to swap in, swap out or copy. When
        anything was scheduled, the model executor runs it asynchronously.
        The outputs are then processed into the returned request outputs.
        Stats are logged and tracing is recorded on every call, including
        iterations where nothing was scheduled.
        """
        scheduler = self.scheduler[virtual_engine]
        seq_group_metadata_list, scheduler_outputs = scheduler.schedule()

        if scheduler_outputs.is_empty():
            # Nothing scheduled: skip model execution this iteration.
            output = []
        else:
            # Execute the model.
            finished_ids = scheduler.get_and_reset_finished_requests_ids()
            model_request = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_ids)
            output = await self.model_executor.execute_model_async(
                model_request)

        step_results = self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

        # Log stats.
        self.do_log_stats(scheduler_outputs, output)

        # Tracing
        self.do_tracing(scheduler_outputs)

        return step_results

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def process_model_inputs_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        """Turn ``inputs`` into processed :class:`LLMInputs`.

        A bare string is treated as ``{"prompt": inputs}``. Token ids are
        taken from ``prompt_token_ids`` when present; otherwise the prompt is
        tokenized asynchronously. With a prompt adapter, one ``0`` id per
        virtual token is prepended. The result goes through
        ``self.input_processor``.
        """
        if isinstance(inputs, str):
            inputs = {"prompt": inputs}

        if "prompt_token_ids" in inputs:
            token_ids = inputs["prompt_token_ids"]
        else:
            tokenizer = self.get_tokenizer_group("prompts must be None if "
                                                 "skip_tokenizer_init is True")
            token_ids = await tokenizer.encode_async(
                request_id=request_id,
                prompt=inputs["prompt"],
                lora_request=lora_request)

        if prompt_adapter_request:
            num_virtual_tokens = (
                prompt_adapter_request.prompt_adapter_num_virtual_tokens)
            # Placeholder ids reserving room for the adapter's virtual tokens.
            token_ids = [0] * num_virtual_tokens + token_ids

        return self.input_processor(
            LLMInputs(prompt_token_ids=token_ids,
                      prompt=inputs.get("prompt"),
                      multi_modal_data=inputs.get("multi_modal_data")))

    async def add_request_async(
            self,
            request_id: str,
            inputs: PromptInputs,
            params: Union[SamplingParams, PoolingParams],
            arrival_time: Optional[float] = None,
            lora_request: Optional[LoRARequest] = None,
            trace_headers: Optional[Dict[str, str]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> None:
        """Validate, preprocess and enqueue a request in the engine.

        Raises:
            ValueError: If a LoRA request is given but LoRA is not enabled.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        arrival_time = time.time() if arrival_time is None else arrival_time

        processed = await self.process_model_inputs_async(
            request_id=request_id,
            inputs=inputs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
        )

    async def check_health_async(self) -> None:
        """Health-check the tokenizer (when one is configured), then the
        model executor."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()
337

338

339
class AsyncLLMEngine:
340
    """An asynchronous wrapper for :class:`LLMEngine`.
341

342
343
344
345
346
    This class is used to wrap the :class:`LLMEngine` class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The :class:`LLMEngine` is kicked by the
    generate method when there are requests in the waiting queue. The generate
    method yields the outputs from the :class:`LLMEngine` to the caller.
347
348
349
350
351

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
Zhuohan Li's avatar
Zhuohan Li committed
352
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
353
354
            async frontend will be executed in a separate process as the
            model workers.
355
        log_requests: Whether to log the requests.
zspo's avatar
zspo committed
356
357
        max_log_len: Maximum number of prompt characters or prompt ID numbers
            being printed in log.
358
359
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
360
361
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
362
    """
363

Antoni Baum's avatar
Antoni Baum committed
364
365
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

366
367
368
369
370
    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 max_log_len: Optional[int] = None,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        """Store the frontend options and build the wrapped engine.

        ``*args`` / ``**kwargs`` are forwarded unchanged to ``_init_engine``.
        """
        # Options. ``_init_engine`` reads ``worker_use_ray`` and
        # ``engine_use_ray``, so both must be assigned before it is called.
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
        self.max_log_len = max_log_len
        self.start_engine_loop = start_engine_loop

        # Background loop state, populated by start_background_loop().
        self.background_loop: Optional[asyncio.Future] = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self._errored_with: Optional[BaseException] = None

        self.engine = self._init_engine(*args, **kwargs)

        # Lazy initialized fields
        self._request_tracker: RequestTracker

391
    @classmethod
    def _get_executor_cls(
            cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
        """Choose the async executor class for ``engine_config``.

        Resolution order:
          1. ``distributed_executor_backend`` given as a class, which must
             subclass ``ExecutorAsyncBase``;
          2. the device type (neuron, tpu, cpu, openvino, xpu);
          3. the ``"ray"`` or ``"mp"`` backend names;
          4. otherwise the single-process GPU executor.
        Ray-backed choices initialize the Ray cluster before returning.
        """
        backend = engine_config.parallel_config.distributed_executor_backend

        if isinstance(backend, type):
            if not issubclass(backend, ExecutorAsyncBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorAsyncBase. Got {backend}.")
            if backend.uses_ray:  # type: ignore
                initialize_ray_cluster(engine_config.parallel_config)
            return backend

        device_type = engine_config.device_config.device_type

        if device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutorAsync
            return NeuronExecutorAsync

        if device_type == "tpu":
            from vllm.executor.tpu_executor import TPUExecutorAsync
            return TPUExecutorAsync

        if device_type == "cpu":
            assert backend is None, (
                "Distributed execution is not supported with the CPU backend.")
            from vllm.executor.cpu_executor import CPUExecutorAsync
            return CPUExecutorAsync

        if device_type == "openvino":
            assert backend is None, (
                "Distributed execution is not supported with "
                "the OpenVINO backend.")
            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
            return OpenVINOExecutorAsync

        if device_type == "xpu":
            if backend is None:
                from vllm.executor.xpu_executor import XPUExecutorAsync
                return XPUExecutorAsync
            if backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                return RayXPUExecutorAsync
            raise RuntimeError(
                "Not supported distributed execution model on XPU device.")

        if backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            return RayGPUExecutorAsync

        if backend == "mp":
            from vllm.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            return MultiprocessingGPUExecutorAsync

        from vllm.executor.gpu_executor import GPUExecutorAsync
        return GPUExecutorAsync

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments.

        When ``engine_args.engine_use_ray`` is set, Ray availability is
        asserted before the executor class is resolved.
        """
        engine_config = engine_args.create_engine_config()

        if engine_args.engine_use_ray:
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()

        executor_class = cls._get_executor_cls(engine_config)
        worker_use_ray = executor_class.uses_ray

        # Create the async LLM engine.
        return cls(
            worker_use_ray,
            engine_args.engine_use_ray,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            max_log_len=engine_args.max_log_len,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

478
479
    @property
    def is_running(self) -> bool:
        """True while the background loop task exists and has not finished."""
        task = self._background_loop_unshielded
        if self.background_loop is None or task is None:
            return False
        return not task.done()

    @property
    def is_stopped(self) -> bool:
        """True if the engine has errored, or its background loop task was
        started and has since finished."""
        if self.errored:
            return True
        task = self._background_loop_unshielded
        return (self.background_loop is not None and task is not None
                and task.done())

    @property
    def errored(self) -> bool:
        """True once an error has been recorded via ``set_errored``."""
        recorded = self._errored_with
        return recorded is not None

    def set_errored(self, exc: Exception) -> None:
        """Record ``exc`` as the reason the engine is unusable."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine errored and forward ``exc`` to every tracked
        request stream."""
        self.set_errored(exc)
        tracker = self._request_tracker
        tracker.propagate_exception(exc)
500

501
502
503
504
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> "PreTrainedTokenizer":
        """Return the tokenizer for ``lora_request`` (the base tokenizer when
        it is None), fetched from the Ray actor when the engine is remote."""
        if not self.engine_use_ray:
            tokenizer_group = self.engine.get_tokenizer_group()
            return await tokenizer_group.get_lora_tokenizer_async(lora_request)

        return await self.engine.get_tokenizer.remote(  # type: ignore
            lora_request)
511

512
    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        loop = asyncio.get_event_loop()
        engine_task = loop.create_task(self.run_engine_loop())
        self._background_loop_unshielded = engine_task
        # Report how the loop ended; failures are propagated to all streams.
        engine_task.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        self.background_loop = asyncio.shield(engine_task)
Antoni Baum's avatar
Antoni Baum committed
527
528
529

    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
        """Instantiate ``self._engine_class`` locally, or as a Ray actor when
        ``engine_use_ray`` is set.

        The actor reserves no CPUs when workers also use Ray. Otherwise it
        reserves GPUs: a ``gpu_memory_utilization`` fraction for a single
        worker (TP == PP == 1), or one whole GPU.
        """
        if not self.engine_use_ray:
            return self._engine_class(*args, **kwargs)

        if self.worker_use_ray:
            actor_resources = {"num_cpus": 0}
        else:
            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
            # order of the arguments.
            cache_config = kwargs["cache_config"]
            parallel_config = kwargs["parallel_config"]
            single_worker = (parallel_config.tensor_parallel_size == 1
                             and parallel_config.pipeline_parallel_size == 1)
            actor_resources = {
                "num_gpus": (cache_config.gpu_memory_utilization
                             if single_worker else 1)
            }
        remote_engine_cls = ray.remote(**actor_resources)(self._engine_class)
        return remote_engine_cls.remote(*args, **kwargs)

548
    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Pending new requests are added to the engine. Validation
        ``ValueError``s are routed to the request's stream. Finished or
        aborted ids are then aborted in the engine, one step runs, and its
        outputs are routed to their streams.

        Returns True if there are in-progress requests."""
        new_requests, finished_requests = (
            self._request_tracker.get_new_and_finished_requests())

        for request_kwargs in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            # TODO: Maybe add add_request_batch to reduce Ray overhead
            try:
                if not self.engine_use_ray:
                    await self.engine.add_request_async(**request_kwargs)
                else:
                    await self.engine.add_request.remote(  # type: ignore
                        **request_kwargs)
            except ValueError as err:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    request_kwargs["request_id"],
                    err,
                    verbose=self.log_requests,
                )

        if finished_requests:
            await self._engine_abort(finished_requests)

        if not self.engine_use_ray:
            step_outputs = await self.engine.step_async(virtual_engine)
        else:
            step_outputs = await self.engine.step.remote()  # type: ignore

        # Put the outputs into the corresponding streams.
        for step_output in step_outputs:
            self._request_tracker.process_request_output(
                step_output, verbose=self.log_requests)

        return any(not step_output.finished for step_output in step_outputs)
589

Antoni Baum's avatar
Antoni Baum committed
590
591
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Forward an abort for ``request_ids`` to the underlying engine."""
        if not self.engine_use_ray:
            self.engine.abort_request(request_ids)
            return
        await self.engine.abort_request.remote(request_ids)  # type: ignore

    async def run_engine_loop(self):
        """Drive the engine with one ``engine_step`` task per virtual engine.

        The loop uses a single virtual engine when the engine is a Ray actor,
        otherwise ``pipeline_parallel_size`` of them.

        When no virtual engine has work, remote workers are told to leave
        their execution loop and the loop blocks until a new request arrives.
        Otherwise it waits, bounded by ``ENGINE_ITERATION_TIMEOUT_S``, for any
        step to complete. A new step is scheduled for that virtual engine if
        the step reported in-progress requests or the engine still has
        unfinished requests for it.

        The loop never returns normally: a timeout marks the engine as
        errored and is re-raised.
        """
        if self.engine_use_ray:
            num_virtual_engines = 1  # type: ignore
        else:
            num_virtual_engines = (
                self.engine.parallel_config.pipeline_parallel_size)
        engine_busy = [False] * num_virtual_engines

        while True:
            if not any(engine_busy):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                if self.engine_use_ray:
                    await (self.engine.stop_remote_worker_execution_loop.
                           remote()  # type: ignore
                           )
                else:
                    await self.engine.stop_remote_worker_execution_loop_async()
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")
                step_tasks = [
                    asyncio.create_task(self.engine_step(ve))
                    for ve in range(num_virtual_engines)
                ]
                engine_busy = [True] * num_virtual_engines

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        step_tasks, return_when=asyncio.FIRST_COMPLETED)
                    # Yield to the event loop once per virtual engine
                    # (presumably so the other step tasks can advance).
                    for _ in range(num_virtual_engines):
                        await asyncio.sleep(0)
                for completed in done:
                    step_had_work = completed.result()
                    ve = step_tasks.index(completed)
                    if self.engine_use_ray:
                        has_unfinished = (
                            await (self.engine.
                                   has_unfinished_requests_for_virtual_engine.
                                   remote(  # type: ignore
                                       ve)))
                    else:
                        has_unfinished = (
                            self.engine.
                            has_unfinished_requests_for_virtual_engine(ve))
                    keep_stepping = bool(step_had_work or has_unfinished)
                    if keep_stepping:
                        step_tasks[ve] = asyncio.create_task(
                            self.engine_step(ve))
                    engine_busy[ve] = keep_stepping
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise
            await asyncio.sleep(0)

    async def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncStream:
        """Register a request with the request tracker and return its stream.

        When ``self.log_requests`` is set, the prompt text and/or prompt token
        ids are logged, each truncated to ``self.max_log_len`` items when that
        limit is not ``None``.

        If the background loop is not running, it is started when
        ``self.start_engine_loop`` is set; otherwise
        :class:`AsyncEngineDeadError` is raised.

        Args:
            request_id: The unique id of the request.
            inputs: Either a plain prompt string or a mapping that may carry
                ``"prompt"`` and/or ``"prompt_token_ids"``.
            params: Sampling or pooling parameters of the request.
            arrival_time: Arrival timestamp; ``time.time()`` when omitted.
            lora_request: LoRA request to use, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt adapter request to use, if any.

        Returns:
            The stream object returned by the request tracker.

        Raises:
            AsyncEngineDeadError: If the loop is down and auto-start is off.
        """
        if self.log_requests:
            if isinstance(inputs, str):
                logged_prompt, logged_token_ids = inputs, None
            else:
                logged_prompt = inputs.get("prompt")
                logged_token_ids = inputs.get("prompt_token_ids")

            limit = self.max_log_len
            if limit is not None:
                # Truncate whichever representations are present.
                logged_prompt = (None if logged_prompt is None else
                                 logged_prompt[:limit])
                logged_token_ids = (None if logged_token_ids is None else
                                    logged_token_ids[:limit])

            logger.info(
                "Received request %s: prompt: %r, "
                "params: %s, prompt_token_ids: %s, "
                "lora_request: %s.", request_id, logged_prompt, params,
                logged_token_ids, lora_request)

        if not self.is_running:
            if not self.start_engine_loop:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")
            self.start_background_loop()

        timestamp = time.time() if arrival_time is None else arrival_time
        return self._request_tracker.add_request(
            request_id,
            inputs=inputs,
            params=params,
            arrival_time=timestamp,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request)

    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncIterator[RequestOutput]:
        """Generate outputs for a request.

        This is an async generator. It submits the request through
        ``_process_request`` and yields every output produced for it. Each
        output is passed through ``LLMEngine.validate_output`` with
        ``RequestOutput`` as the expected type.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> request_id = "0"
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    request_id)
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming HTTP request object
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(request_id)
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        outputs = self._process_request(
            request_id,
            inputs,
            sampling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
        )
        async for output in outputs:
            yield LLMEngine.validate_output(output, RequestOutput)

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
    ) -> AsyncIterator[EmbeddingRequestOutput]:
        """Generate outputs for a request from an embedding model.

        This is an async generator. It submits the request through
        ``_process_request`` and yields every output produced for it. Each
        output is passed through ``LLMEngine.validate_output`` with
        ``EmbeddingRequestOutput`` as the expected type.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.

        Yields:
            The output `EmbeddingRequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> request_id = "0"
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    request_id)
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` is the incoming HTTP request object
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(request_id)
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        outputs = self._process_request(
            request_id,
            inputs,
            pooling_params,
            lora_request=lora_request,
            trace_headers=trace_headers,
        )
        async for output in outputs:
            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)

    async def _process_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        *,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Dict[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Common logic to process requests with SamplingParams or
        PoolingParams.

        Submits the request via :meth:`add_request`, using the current time
        as its arrival time, and yields every output from the returned stream.

        If iteration ends abnormally, the request is aborted and the exception
        is re-raised. This covers:

        - an exception coming out of the stream;
        - cancellation of the consuming task;
        - the consumer closing or abandoning this generator before the stream
          is exhausted (``GeneratorExit``).
        """
        arrival_time = time.time()

        stream = await self.add_request(
            request_id,
            inputs,
            params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
        )

        try:
            async for request_output in stream:
                yield request_output
        except (Exception, asyncio.CancelledError, GeneratorExit):
            # GeneratorExit is thrown at the ``yield`` when the consumer calls
            # ``aclose()`` or drops this generator early. It is not an
            # Exception subclass, so without listing it the request would keep
            # running in the engine with nobody reading its outputs.
            self._abort(request_id)
            # Bare ``raise`` re-raises the active exception with its original
            # traceback; GeneratorExit in particular must propagate.
            raise

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if self.is_running:
            return self._abort(request_id)
        raise AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")

    def _abort(self, request_id: str) -> None:
928
929
930
931
932
933
934
935
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
936
937
        self._request_tracker.abort_request(request_id,
                                            verbose=self.log_requests)
938

939
940
941
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine.

        A Ray-backed engine is queried with a remote call that is awaited.
        Otherwise the configuration comes directly from the local engine.
        """
        if not self.engine_use_ray:
            return self.engine.get_model_config()
        return await self.engine.get_model_config.remote()  # type: ignore

    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine.

        A Ray-backed engine is queried with a remote call that is awaited.
        Otherwise the configuration comes directly from the local engine.
        """
        if not self.engine_use_ray:
            return self.engine.get_decoding_config()
        remote_call = self.engine.get_decoding_config.remote  # type: ignore
        return await remote_call()

    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the underlying engine to log its stats.

        ``scheduler_outputs`` and ``model_output`` are forwarded unchanged to
        the engine's ``do_log_stats`` in both the Ray and the local case.

        Args:
            scheduler_outputs: Optional scheduler outputs to log stats for.
            model_output: Optional sampler outputs to log stats for.
        """
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote(  # type: ignore
                scheduler_outputs, model_output)
        else:
            # The Ray branch already passes both arguments to the same engine
            # class. The local branch used to drop them silently, so results
            # depended on whether Ray was in use.
            self.engine.do_log_stats(scheduler_outputs, model_output)

    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy.

        Raises:
            AsyncEngineDeadError: If the background loop is stopped.
            RuntimeError: If the Ray engine actor has died (wraps the
                ``RayActorError``).
        """
        started = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        if not self.engine_use_ray:
            await self.engine.check_health_async()
        else:
            try:
                await self.engine.check_health.remote()  # type: ignore
            except ray.exceptions.RayActorError as e:
                raise RuntimeError("Engine is dead.") from e

        elapsed = time.perf_counter() - started
        logger.debug("Health check took %fs", elapsed)

    async def is_tracing_enabled(self) -> bool:
        """Report whether tracing is enabled on the underlying engine.

        A Ray-backed engine is queried with an awaited remote call. Otherwise
        the local engine is asked directly.
        """
        if not self.engine_use_ray:
            return self.engine.is_tracing_enabled()
        remote_call = self.engine.is_tracing_enabled.remote  # type: ignore
        return await remote_call()

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register ``logger`` under ``logger_name`` on the underlying engine.

        With a Ray-backed engine the registration runs on the actor, and this
        call blocks on ``ray.get`` until it completes.
        """
        if not self.engine_use_ray:
            self.engine.add_logger(logger_name=logger_name, logger=logger)
            return
        pending = self.engine.add_logger.remote(  # type: ignore
            logger_name=logger_name, logger=logger)
        ray.get(pending)

    def remove_logger(self, logger_name: str) -> None:
        """Remove the stat logger registered as ``logger_name``.

        With a Ray-backed engine the removal runs on the actor, and this call
        blocks on ``ray.get`` until it completes.
        """
        if not self.engine_use_ray:
            self.engine.remove_logger(logger_name=logger_name)
            return
        pending = self.engine.remove_logger.remote(  # type: ignore
            logger_name=logger_name)
        ray.get(pending)