"tools/vscode:/vscode.git/clone" did not exist on "1f8b7c536be40975573eeebf36204286cfb4e4e9"
async_llm_engine.py 41.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import asyncio
import time
6
import weakref
Antoni Baum's avatar
Antoni Baum committed
7
from functools import partial
8
9
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                    Mapping, Optional, Set, Tuple, Type, Union)
10
from weakref import ReferenceType
11

12
import vllm.envs as envs
13
14
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, VllmConfig)
15
from vllm.core.scheduler import SchedulerOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
16
from vllm.engine.arg_utils import AsyncEngineArgs
17
from vllm.engine.async_timeout import asyncio_timeout
18
from vllm.engine.llm_engine import LLMEngine
19
from vllm.engine.metrics_types import StatLoggerBase
20
from vllm.engine.protocol import EngineClient
21
from vllm.executor.executor_base import ExecutorBase
22
from vllm.inputs import PromptType
23
from vllm.inputs.preprocess import InputPreprocessor
Woosuk Kwon's avatar
Woosuk Kwon committed
24
from vllm.logger import init_logger
25
from vllm.lora.request import LoRARequest
26
from vllm.model_executor.layers.sampler import SamplerOutput
27
from vllm.outputs import PoolingRequestOutput, RequestOutput
28
from vllm.pooling_params import PoolingParams
29
from vllm.sampling_params import SamplingParams
30
from vllm.sequence import ExecuteModelRequest
31
from vllm.transformers_utils.tokenizer import AnyTokenizer
yhu422's avatar
yhu422 committed
32
from vllm.usage.usage_lib import UsageContext
33
from vllm.utils import Device, deprecate_kwargs, weak_bind
34
35

# Module-level logger shared by every function and class in this file.
logger = init_logger(__name__)
36
# Snapshot of the VLLM_ENGINE_ITERATION_TIMEOUT_S setting taken at import time.
# NOTE(review): presumably the per-iteration timeout (seconds) applied by the
# engine loop; the usage site is outside this excerpt - confirm.
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
37

Antoni Baum's avatar
Antoni Baum committed
38

39
40
41
42
class AsyncEngineDeadError(RuntimeError):
    """Raised when the engine background loop has stopped or failed."""


43
44
45
46
47
48
49
def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback for the `engine.run_engine_loop()` task (and only that).

    That task is a `while True` loop, so the only legitimate ways for it to
    end are cancellation or an exception.

    Args:
        task: The finished background task whose outcome is inspected.
        error_callback: Invoked with the failure before it is re-raised.

    Raises:
        AsyncEngineDeadError: If the task ended for any reason other than
            cancellation (including a plain return), chained to the cause.
    """
    try:
        result = task.result()
        # A normal return is itself a failure; raising inside the `try` routes
        # it through the generic handler below like any other error.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {result}")
    except asyncio.exceptions.CancelledError:
        # Cancellation is treated as a graceful shutdown, which should only
        # happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as err:
        logger.error("Engine background task failed", exc_info=err)
        error_callback(err)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on GitHub. See stack trace above for the "
            "actual cause.") from err
69
70


71
72
73
# End-of-stream marker: AsyncStream.finish() enqueues it when no exception is
# supplied, and AsyncStream.generator() returns when it dequeues it.
STOP_ITERATION = Exception()  # Sentinel


Antoni Baum's avatar
Antoni Baum committed
74
class AsyncStream:
    """A stream of RequestOutputs for a request that can be iterated over
    asynchronously via an async generator.

    Items are buffered in an `asyncio.Queue`. Once `finish()` has been
    called, later `put()`/`finish()` calls are ignored.
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        """Create a stream for `request_id`.

        Args:
            request_id: Identifier of the request this stream belongs to.
            cancel: Called with `request_id` when the consumer closes the
                generator before the stream has ended.
        """
        self.request_id = request_id
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, Exception]) -> None:
        """Enqueue `item` unless the stream has already been finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(
        self,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ) -> None:
        """Mark the stream finished and enqueue its terminating item.

        Args:
            exception: Exception instance or class that the generator should
                raise. Any non-raisable value (including `None`) ends the
                stream normally via the `STOP_ITERATION` sentinel.
        """
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                exception if self._is_raisable(exception) else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        """Whether `finish()` has been called."""
        return self._finished

    async def generator(self) -> AsyncGenerator[RequestOutput, None]:
        """Yield queued outputs until the stream ends.

        Raises:
            BaseException: Whatever exception (instance or class) was queued
                through `put()` or `finish()`.
            asyncio.CancelledError: If the generator is closed early; the
                cancel callback is invoked first.
        """
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    # Identity check: the sentinel is one specific object, and
                    # `==` would let an exception with a permissive __eq__ be
                    # mistaken for a normal end of stream.
                    if result is STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any) -> bool:
        """Return True if `value` is an exception instance or exception class."""
        return isinstance(value, BaseException) or \
                (isinstance(value, type) and \
                 issubclass(value, BaseException))

Antoni Baum's avatar
Antoni Baum committed
120

121
122
123
124
125
class RequestTracker:
    """Synchronous bookkeeping for requests handed to the engine loop.

    New and aborted requests are buffered in queues and collected by
    `get_new_and_aborted_requests()`; active requests are kept in a
    request-id -> `AsyncStream` mapping.
    """

    def __init__(self) -> None:
        self._request_streams: Dict[str, AsyncStream] = {}
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        """Return True if `item` is the id of an active request stream."""
        return item in self._request_streams

    def __len__(self) -> int:
        """Number of active request streams."""
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Abort one stream with `exc`, or every stream when `request_id`
        is None."""
        if request_id is not None:
            targets: Tuple[str, ...] = (request_id, )
        else:
            # Snapshot the ids first: abort_request() pops from the mapping,
            # so it cannot be iterated directly.
            targets = tuple(self._request_streams.keys())
        for rid in targets:
            self.abort_request(rid, exception=exc)

    def process_request_output(self,
                               request_output: RequestOutput,
                               *,
                               verbose: bool = False) -> None:
        """Route an engine output to its stream, finishing the stream when
        the output is marked finished."""
        rid = request_output.request_id
        is_done = request_output.finished

        # Finished requests leave the mapping; unfinished ones stay in it.
        lookup = (self._request_streams.pop
                  if is_done else self._request_streams.get)
        stream = lookup(rid, None)

        # The stream may be missing if the request was aborted while this
        # output was being generated.
        if stream is not None:
            stream.put(request_output)
            if is_done:
                stream.finish()

        if verbose and is_done:
            logger.info("Finished request %s.", rid)

    def process_exception(self,
                          request_id: str,
                          exception: BaseException,
                          *,
                          verbose: bool = False) -> None:
        """Abort `request_id`, delivering `exception` to its stream."""
        if verbose:
            logger.info("Finished request %s.", request_id)
        self.abort_request(request_id, exception=exception)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Queue a request for the next background loop iteration.

        Returns:
            The `AsyncStream` on which the request's outputs will arrive.

        Raises:
            KeyError: If `request_id` is already an active request.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        # The stream's cancel hook aborts the request through this tracker.
        on_cancel = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, on_cancel)
        payload = {"request_id": request_id, **engine_add_request_kwargs}
        self._new_requests.put_nowait((stream, payload))

        self.new_requests_event.set()

        if verbose:
            logger.info("Added request %s.", request_id)

        return stream

    def abort_request(self,
                      request_id: str,
                      *,
                      exception: Optional[Union[BaseException,
                                                Type[BaseException]]] = None,
                      verbose: bool = False) -> None:
        """Queue `request_id` for abortion and finish its stream, if any,
        right away (with `exception` when one is given)."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._aborted_requests.put_nowait(request_id)

        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)

    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Drain both queues.

        Returns:
            A pair `(new_requests, aborted_ids)`. Requests that were aborted
            before being picked up are dropped from both collections.
        """
        aborted_ids: Set[str] = set()
        while not self._aborted_requests.empty():
            aborted_ids.add(self._aborted_requests.get_nowait())

        pending: List[Dict] = []
        while not self._new_requests.empty():
            stream, payload = self._new_requests.get_nowait()
            rid = stream.request_id
            if rid not in aborted_ids:
                self._request_streams[rid] = stream
                pending.append(payload)
                continue
            # Aborted before it ever reached the engine: cancel its stream
            # and do not report the id as aborted.
            stream.finish(asyncio.CancelledError)
            aborted_ids.discard(rid)

        return pending, aborted_ids

    async def wait_for_new_requests(self):
        """Wait until a new request is queued (returning at once if one
        already is), then clear the event."""
        event = self.new_requests_event
        if not self.has_new_requests():
            await event.wait()
        event.clear()

    def has_new_requests(self):
        """Return True if at least one new request is waiting in the queue."""
        queue_is_empty = self._new_requests.empty()
        return not queue_is_empty
252

Antoni Baum's avatar
Antoni Baum committed
253
254
255
256

class _AsyncLLMEngine(LLMEngine):
    """LLMEngine subclass that adds coroutine-based entry points."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def step_async(self, virtual_engine: int) -> List[RequestOutput]:
        """Run one decoding iteration and return the newly generated results.
        The workers are run asynchronously if possible.

        One iteration first schedules the sequences to execute next together
        with the token blocks to swap in/out/copy, then executes the model
        and updates the scheduler with the model outputs, and finally decodes
        the sequences.

        Args:
            virtual_engine: Index selecting the scheduler, cached scheduler
                outputs, scheduler context and async callback to use.

        Returns:
            The request outputs accumulated in the scheduler context.
        """
        # Outputs cached by previous iterations (None on the first one).
        cached = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached.seq_group_metadata_list
        scheduler_outputs = cached.scheduler_outputs
        allow_async_output_proc = cached.allow_async_output_proc

        ctx = self.scheduler_contexts[virtual_engine]

        # Each scheduler iteration starts with an empty output list.
        ctx.request_outputs.clear()

        # The scheduler is skipped while the seq groups still have steps
        # left, so it only runs again once the current batch has completed.
        if self._has_remaining_steps(seq_group_metadata_list):
            finished_requests_ids = []
        else:
            # Schedule iteration
            schedule_result = self.scheduler[virtual_engine].schedule()
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc) = schedule_result

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            if not scheduler_outputs.is_empty():
                # Only collect (and reset) the finished ids when something was
                # scheduled. NOTE(review): the original comment says doing
                # otherwise makes mamba_cache/minimax_cache fail to release
                # the finished request ids of the last steps.
                finished_requests_ids = self.scheduler[
                    virtual_engine].get_and_reset_finished_requests_ids()

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if scheduler_outputs.is_empty():
            # Nothing to execute: just flush any pending async outputs.
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            outputs = []
        else:
            # Reuse the last_output cached by the previous iteration, if any.
            # For PP this is probably the best way to pass the
            # sampled_token_ids, since a separate broadcast over all the PP
            # stages would let one virtual engine's microbatch block the
            # pipeline.
            last_sampled_token_ids = (
                self._get_last_sampled_token_ids(virtual_engine))

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # ExecuteModelRequest carries the last sampled_token_ids to
                # each non-last PP stage for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)

            if allow_async_output_proc:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]

            # Execute the model.
            outputs = await self.model_executor.execute_model_async(
                execute_model_req)

        if self._has_remaining_steps(seq_group_metadata_list):
            # Multi-step case
            return ctx.request_outputs

        # True only when the first seq group reports num_steps == 1 (the
        # original comment states this holds when the num_steps of all the
        # sequences are 1); False when nothing was scheduled.
        is_first_step_output: bool = (
            bool(seq_group_metadata_list)
            and seq_group_metadata_list[0].state.num_steps == 1)

        ctx.append_output(outputs=outputs,
                          seq_group_metadata_list=seq_group_metadata_list,
                          scheduler_outputs=scheduler_outputs,
                          is_async=allow_async_output_proc,
                          is_last_step=True,
                          is_first_step_output=is_first_step_output)

        if outputs and allow_async_output_proc:
            assert len(
                outputs
            ) == 1, "Async postprocessor expects only a single output set"
            self._advance_to_next_step(
                outputs[0], seq_group_metadata_list,
                scheduler_outputs.scheduled_seq_groups)

        if not allow_async_output_proc:
            self._process_model_outputs(ctx=ctx)

            # Log stats.
            self.do_log_stats(scheduler_outputs, outputs)

            # Tracing
            self.do_tracing(scheduler_outputs)

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

        return ctx.request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Ask the model executor to stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def get_tokenizer_async(self,
                                  lora_request: Optional[LoRARequest] = None
                                  ) -> AnyTokenizer:
        """Fetch the tokenizer for `lora_request` from the tokenizer group."""
        tokenizer_group = self.get_tokenizer_group()
        return await tokenizer_group.get_lora_tokenizer_async(lora_request)

    async def add_request_async(
        self,
        request_id: str,
        prompt: PromptType,
        params: SamplingParams,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        """
        Async version of
        [`add_request`][vllm.engine.llm_engine.LLMEngine.add_request].

        Raises:
            ValueError: If a LoRA request is given while LoRA is disabled, a
                non-zero priority is given without priority scheduling, or
                `data_parallel_rank` is set (only supported in the v1 client).
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if priority != 0 and self.scheduler_config.policy != "priority":
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")
        if arrival_time is None:
            arrival_time = time.time()

        if data_parallel_rank is not None:
            raise ValueError("Targeting data_parallel_rank only supported "
                             "in v1 client.")

        if (isinstance(prompt, dict)
                and prompt.get("prompt_embeds") is not None
                and not prompt.get("prompt_token_ids")):
            # Embeds-only prompt: fill in one placeholder token id (0) per
            # position. The -2 dimension is used (instead of 0) in case a
            # batched input of batch size 1 is passed in.
            num_positions = prompt["prompt_embeds"].shape[-2]
            prompt["prompt_token_ids"] = [0] * num_positions

        processed_inputs = await self.input_preprocessor.preprocess_async(
            prompt,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=priority,
        )

    async def check_health_async(self) -> None:
        """Run the model executor's (synchronous) health check."""
        self.model_executor.check_health()

    async def collective_rpc_async(self,
                                   method: str,
                                   timeout: Optional[float] = None,
                                   args: tuple = (),
                                   kwargs: Optional[dict] = None):
        """Not implemented by this engine; always raises NotImplementedError."""
        raise NotImplementedError

461

462
class AsyncLLMEngine(EngineClient):
463
    """An asynchronous wrapper for [`LLMEngine`][vllm.LLMEngine].
464

465
466
467
468
469
470
    This class is used to wrap the [`LLMEngine`][vllm.LLMEngine] class to
    make it asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The [`LLMEngine`][vllm.LLMEngine] is kicked
    by the generate method when there are requests in the waiting queue. The
    generate method yields the outputs from the [`LLMEngine`][vllm.LLMEngine]
    to the caller.
471
472

    Args:
473
        log_requests: Whether to log the requests.
474
475
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
476
477
        *args: Arguments for [`LLMEngine`][vllm.LLMEngine].
        **kwargs: Arguments for [`LLMEngine`][vllm.LLMEngine].
478
    """
479

Antoni Baum's avatar
Antoni Baum committed
480
481
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

482
    def __init__(self,
                 *args: Any,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs: Any) -> None:
        """Build the wrapped engine and reset the background-loop state.

        Args:
            *args: Forwarded to the wrapped engine class.
            log_requests: Whether to log the requests.
            start_engine_loop: Stored on the instance as-is.
            **kwargs: Forwarded to the wrapped engine class.

        Raises:
            ValueError: If `envs.VLLM_USE_V1` is set.
        """
        if envs.VLLM_USE_V1:
            raise ValueError(
                "Using V0 AsyncLLMEngine, but envs.VLLM_USE_V1=True. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        self.log_requests = log_requests
        self.engine = self._engine_class(*args, **kwargs)

        # With async output processing the engine pushes request outputs
        # through a callback, so the append to the asyncio queues is not
        # delayed (especially for multi-step).
        callback_enabled = self.engine.model_config.use_async_output_proc
        self.use_process_request_outputs_callback = callback_enabled
        if callback_enabled:
            self.engine.process_request_outputs_callback = weak_bind(
                self.process_request_outputs)

        self.background_loop: Optional[asyncio.Future] = None
        # The unshielded task is referenced as well so that it is not garbage
        # collected.
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Declared here, assigned lazily in start_background_loop().
        self._request_tracker: RequestTracker

518
519
520
521
522
    def __del__(self):
        """Wake the engine loop on finalization so that it can exit cleanly.

        The tracker only exists once the background loop has been started,
        hence the defensive attribute lookups.
        """
        # The tracker is stored as `_request_tracker`; the public name is
        # kept as a fallback in case a `request_tracker` attribute exists
        # outside the visible part of this class.
        rt = getattr(self, "_request_tracker", None)
        if rt is None:
            rt = getattr(self, "request_tracker", None)
        # Compare with None explicitly: RequestTracker defines __len__, so a
        # tracker with no active streams is falsy - and an idle tracker is
        # exactly the case where the loop is parked waiting for this event.
        if rt is not None:
            # Wake up engine loop so that it will exit cleanly
            rt.new_requests_event.set()

523
    @classmethod
    def _get_executor_cls(cls,
                          engine_config: VllmConfig) -> Type[ExecutorBase]:
        """Return the executor class for `engine_config`, as chosen by
        `LLMEngine._get_executor_cls`."""
        executor_cls = LLMEngine._get_executor_cls(engine_config)
        return executor_cls
527

528
    @classmethod
    @deprecate_kwargs(
        "disable_log_requests",
        additional_message=("This argument will have no effect. "
                            "Use `enable_log_requests` instead."),
    )
    def from_vllm_config(
            cls,
            vllm_config: VllmConfig,
            start_engine_loop: bool = True,
            usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
            stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
            enable_log_requests: bool = False,
            disable_log_stats: bool = False,
            disable_log_requests: bool = True,  # Deprecated, will be removed
    ) -> "AsyncLLMEngine":
        """Build an engine directly from a `VllmConfig`.

        Args:
            vllm_config: Configuration passed to the engine and used to pick
                the executor class.
            start_engine_loop: Forwarded to the constructor.
            usage_context: Forwarded to the constructor.
            stat_loggers: Forwarded to the constructor.
            enable_log_requests: Passed to the constructor as `log_requests`.
            disable_log_stats: Inverted and passed as `log_stats`.
            disable_log_requests: Deprecated and ignored.
        """
        executor_cls = cls._get_executor_cls(vllm_config)
        log_stats = not disable_log_stats
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_cls,
            start_engine_loop=start_engine_loop,
            log_requests=enable_log_requests,
            log_stats=log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

556
557
558
559
560
561
562
563
564
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLMEngine":
        """Build an async LLM engine from engine arguments.

        The engine config is derived from `engine_args`; when
        `envs.VLLM_USE_V1` is set, construction is delegated to the V1
        `AsyncLLM` class instead of `cls`.
        """
        vllm_config = engine_args.create_engine_config(usage_context)

        if envs.VLLM_USE_V1:
            # Imported lazily, only when the V1 engine is actually selected.
            from vllm.v1.engine.async_llm import AsyncLLM as V1AsyncLLMEngine
            engine_cls = V1AsyncLLMEngine
        else:
            engine_cls = cls

        return engine_cls.from_vllm_config(
            vllm_config=vllm_config,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            disable_log_stats=engine_args.disable_log_stats,
            enable_log_requests=engine_args.enable_log_requests,
        )
581

582
583
    @property
    def is_running(self) -> bool:
        """True while a background loop has been started and its task has
        not yet completed."""
        if self.background_loop is None:
            return False
        task = self._background_loop_unshielded
        return task is not None and not task.done()

    @property
    def is_stopped(self) -> bool:
        """True if the engine has errored, or if a background loop was
        started and its task has completed."""
        if self.errored:
            return True
        if self.background_loop is None:
            return False
        task = self._background_loop_unshielded
        return task is not None and task.done()

    @property
    def errored(self) -> bool:
        """True once an exception has been recorded via `set_errored`."""
        recorded = self._errored_with
        return recorded is not None

598
    @property
    def dead_error(self) -> BaseException:
        """A fresh `AsyncEngineDeadError` saying the background loop is not
        running."""
        message = ("Background loop is not running. If it was running, "
                   "inspect the output to find the stacktrace of the "
                   "error that caused the background loop to stop "
                   "(AsyncEngineDeadError).")
        return AsyncEngineDeadError(message)
605

606
607
608
609
610
611
    def set_errored(self, exc: Exception) -> None:
        """Record `exc` as the error that put the engine in the errored
        state."""
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Record `exc` on the engine, then propagate it to every request
        stream held by the tracker."""
        self.set_errored(exc)
        tracker = self._request_tracker
        tracker.propagate_exception(exc)
612

613
614
615
    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Return the wrapped engine's input preprocessor."""
        preprocessor = self.engine.input_preprocessor
        return preprocessor

616
617
618
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the wrapped engine's tokenizer for `lora_request`."""
        tokenizer = await self.engine.get_tokenizer_async(lora_request)
        return tokenizer
621

622
    def start_background_loop(self) -> None:
        """Create the request tracker and launch the engine loop task.

        Raises:
            AsyncEngineDeadError: If the engine has already errored.
            RuntimeError: If the background loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        # The RequestTracker is created here, rather than in __init__, so
        # that it uses the right event loop.
        self._request_tracker = RequestTracker()

        # The loop coroutine only receives a weak reference to this engine.
        loop = asyncio.get_event_loop()
        task = loop.create_task(self.run_engine_loop(weakref.ref(self)))
        self._background_loop_unshielded = task

        on_done = partial(_log_task_completion,
                          error_callback=self._error_callback)
        task.add_done_callback(on_done)
        self.background_loop = asyncio.shield(task)
Antoni Baum's avatar
Antoni Baum committed
637

638
639
640
641
642
643
644
645
646
647
648
649
650
651
    def shutdown_background_loop(self) -> None:
        """
        Cancel the background task, if any, and drop both loop references.

        This method needs to be called during cleanup to remove
        references to `self` and properly GC the resources held
        by the async LLM engine (e.g., the executors as well as
        their resources).
        """
        task = self._background_loop_unshielded
        if task is not None:
            task.cancel()
            self._background_loop_unshielded = None
        self.background_loop = None

652
    async def engine_step(self, virtual_engine: int) -> bool:
        """Feed queued requests to the engine and run one engine step.

        Returns True if there are in-progress requests."""
        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())

        # Move every new request into the vLLM engine's waiting queue.
        for request_kwargs in new_requests:
            try:
                await self.engine.add_request_async(**request_kwargs)
            except ValueError as err:
                # TODO: use a vLLM specific error for failed validation
                self._request_tracker.process_exception(
                    request_kwargs["request_id"],
                    err,
                    verbose=self.log_requests,
                )

        if aborted_requests:
            await self._engine_abort(aborted_requests)

        request_outputs = await self.engine.step_async(virtual_engine)

        if self.use_process_request_outputs_callback:
            # The callback was already invoked inside LLMEngine's
            # _process_model_outputs, so only completion has to be detected.
            all_finished = all(output.finished for output in request_outputs)
        else:
            # Put the outputs into the corresponding streams.
            all_finished = self.process_request_outputs(request_outputs)

        return not all_finished

    def process_request_outputs(self, request_outputs) -> bool:
        """Hand each output to the request tracker.

        Returns:
            True if every output is finished (also for an empty input).
        """
        finished_flags = []
        for output in request_outputs:
            # Every output must reach its stream, so there is no early exit.
            self._request_tracker.process_request_output(
                output, verbose=self.log_requests)
            finished_flags.append(output.finished)

        return all(finished_flags)
699

Antoni Baum's avatar
Antoni Baum committed
700
    async def _engine_abort(self, request_ids: Iterable[str]):
        """Forward `request_ids` to the wrapped engine's `abort_request`."""
        engine = self.engine
        engine.abort_request(request_ids)
Antoni Baum's avatar
Antoni Baum committed
702

703
704
705
706
    @staticmethod
    async def run_engine_loop(engine_ref: ReferenceType):
        """Drive the engine until it is garbage collected or an error occurs.

        We use a weakref to the engine so that the running loop
        doesn't prevent the engine being garbage collected.

        The loop keeps one ``engine_step`` task per virtual engine
        (``pipeline_parallel_size`` of them). While no virtual engine has
        work in progress it stops the remote worker execution loop, drops
        its strong reference to the engine and waits on the request tracker
        for new requests.

        Args:
            engine_ref: Weak reference to the ``AsyncLLMEngine`` to drive.

        Raises:
            asyncio.TimeoutError: If waiting for a step exceeds
                ``ENGINE_ITERATION_TIMEOUT_S``; the engine is marked as
                errored via ``set_errored`` before re-raising.
        """
        engine: Optional[AsyncLLMEngine] = engine_ref()
        if not engine:
            # The engine was already collected; nothing to run.
            return

        pipeline_parallel_size = \
                engine.engine.parallel_config.pipeline_parallel_size
        # One flag per virtual engine: True while that virtual engine has a
        # step task that must keep being rescheduled.
        has_requests_in_progress = [False] * pipeline_parallel_size
        while True:
            if not any(has_requests_in_progress):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                await engine.engine.stop_remote_worker_execution_loop_async()
                # Keep the tracker so we can wait on it after the strong
                # engine reference has been dropped.
                request_tracker = engine._request_tracker
                # Allow engine to be garbage collected while
                # waiting for new requests
                del engine
                # Yield to the event loop once, then check whether the
                # engine has gone away in the meantime.
                await asyncio.sleep(0)
                if engine_ref() is None:
                    return
                await request_tracker.wait_for_new_requests()
                # Re-acquire a strong reference now that there is work.
                engine = engine_ref()
                if not engine:
                    return
                logger.debug("Got new requests!")
                # Start one step task per virtual engine; the list index is
                # the virtual engine id.
                requests_in_progress = [
                    asyncio.create_task(engine.engine_step(ve))
                    for ve in range(pipeline_parallel_size)
                ]
                has_requests_in_progress = [True] * pipeline_parallel_size

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        requests_in_progress,
                        return_when=asyncio.FIRST_COMPLETED)
                    # Yield to the event loop once per virtual engine before
                    # handling the completed tasks.
                    for _ in range(pipeline_parallel_size):
                        await asyncio.sleep(0)
                for task in done:
                    # task.result() re-raises any exception the step raised.
                    result = task.result()
                    # The task's position in the list identifies its
                    # virtual engine.
                    virtual_engine = requests_in_progress.index(task)
                    has_unfinished_requests = (
                        engine.engine.
                        has_unfinished_requests_for_virtual_engine(
                            virtual_engine))
                    if result or has_unfinished_requests:
                        # More work remains: schedule the next step for this
                        # virtual engine in the same slot.
                        requests_in_progress[virtual_engine] = (
                            asyncio.create_task(
                                engine.engine_step(virtual_engine)))
                        has_requests_in_progress[virtual_engine] = True
                    else:
                        has_requests_in_progress[virtual_engine] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                engine.set_errored(exc)
                raise
            # Give other coroutines a chance to run between iterations.
            await asyncio.sleep(0)

772
    async def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: SamplingParams,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Register a request with the request tracker.

        If the background loop is not running it is started first, provided
        ``start_engine_loop`` is set.

        Args:
            request_id: The unique id of the request.
            prompt: The prompt to the LLM.
            params: The sampling parameters of the request.
            arrival_time: Arrival timestamp of the request. Only ``None``
                is replaced by ``time.time()``; an explicit ``0.0`` is kept.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request. Must be 0 unless the
                scheduler policy is ``"priority"``.
            data_parallel_rank: Passed through to the request tracker.
            tokenization_kwargs: Passed through to the request tracker.

        Returns:
            The async generator of the stream created by the request
            tracker for this request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                ``start_engine_loop`` is false.
            ValueError: If a non-zero priority is given while priority
                scheduling is not enabled.
        """
        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        if (priority != 0
                and self.engine.scheduler_config.policy != "priority"):
            raise ValueError(f"Got priority {priority} but "
                             "Priority scheduling is not enabled.")

        # Compare against None explicitly: `arrival_time or time.time()`
        # would discard a caller-supplied timestamp of 0.0.
        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            prompt=prompt,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=priority,
            data_parallel_rank=data_parallel_rank,
            tokenization_kwargs=tokenization_kwargs,
        )

        return stream.generator()
813

814
    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        This is an async generator. It registers the request through
        ``add_request`` and streams the resulting outputs back to the
        caller, validating each one with ``LLMEngine.validate_output``.

        Args:
            prompt: The prompt to the LLM. See
                [`PromptType`][vllm.inputs.PromptType] for more details about
                the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: OpenTelemetry trace headers.
            priority: The priority of the request.
                Only applicable with priority scheduling.
            data_parallel_rank: The (global) data parallel rank that must
                handle this request. Only applicable if DP is enabled.

        Yields:
            The output `RequestOutput` objects from the LLMEngine
            for the request.

        Details:
            - ``add_request`` starts the background loop if it is not
              running yet; that loop repeatedly invokes
              [`engine_step`][vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step]
              to process the waiting requests.
            - The request is added to the engine's `RequestTracker`, which
              creates the stream this method iterates over.
            - If the consumer is cancelled (``asyncio.CancelledError``), the
              request is aborted and the cancellation is re-raised.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>> #
            >>> # engine_args is an AsyncEngineArgs instance
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> request_id = "request-0"
            >>> results_generator = engine.generate(
            >>>     "What is LLM?",
            >>>     SamplingParams(temperature=0.0),
            >>>     request_id)
            >>>
            >>> # `request` below is the incoming HTTP request object
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(request_id)
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        try:
            output_stream = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
            )
            async for item in output_stream:
                yield LLMEngine.validate_output(item, RequestOutput)
        except asyncio.CancelledError:
            # Make sure the engine stops working on a request nobody
            # is listening to any more, then propagate the cancellation.
            await self.abort(request_id)
            raise
904

905
    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Pooling entry point; always unsupported here.

        All arguments are accepted but ignored.

        Raises:
            NotImplementedError: Always, at call time.
        """
        unsupported_msg = "Pooling models are not supported in vLLM V0"
        raise NotImplementedError(unsupported_msg)
917

918
    async def abort(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort a submitted request.

        The abort is delegated to ``self._abort``.

        Args:
            request_id: The unique id of the request. Only a single ``str``
                id is accepted.

        Raises:
            RuntimeError: If ``request_id`` is not a ``str``.
            AsyncEngineDeadError: If the background loop is not running.
        """
        single_id_given = isinstance(request_id, str)
        if not single_id_given:
            raise RuntimeError("Only single-request abort supported in"
                               " deprecated V0")

        loop_alive = self.is_running
        if not loop_alive:
            raise AsyncEngineDeadError(
                "Background loop is not running. If it was running, "
                "inspect the output to find the stacktrace of the "
                "error that caused the background loop to stop "
                "(AsyncEngineDeadError).")

        return self._abort(request_id)
938

Antoni Baum's avatar
Antoni Baum committed
939
    def _abort(self, request_id: str) -> None:
        """Ask the request tracker to abort a request.

        The tracker is called with ``asyncio.CancelledError`` as the
        exception and with the ``log_requests`` flag as ``verbose``.

        Args:
            request_id: The unique id of the request.
        """
        tracker = self._request_tracker
        tracker.abort_request(
            request_id,
            exception=asyncio.CancelledError,
            verbose=self.log_requests,
        )
951

952
953
954
955
    async def get_vllm_config(self) -> VllmConfig:
        """Return the vllm configuration reported by the wrapped engine.

        Delegates to ``self.engine.get_vllm_config()``.
        """
        vllm_config = self.engine.get_vllm_config()
        return vllm_config

956
957
    async def get_model_config(self) -> ModelConfig:
        """Return the model configuration reported by the wrapped engine.

        Delegates to ``self.engine.get_model_config()``.
        """
        model_config = self.engine.get_model_config()
        return model_config
959

960
961
    async def get_parallel_config(self) -> ParallelConfig:
        """Return the parallel configuration reported by the wrapped engine.

        Delegates to ``self.engine.get_parallel_config()``.
        """
        parallel_config = self.engine.get_parallel_config()
        return parallel_config
963

964
965
    async def get_decoding_config(self) -> DecodingConfig:
        """Return the decoding configuration reported by the wrapped engine.

        Delegates to ``self.engine.get_decoding_config()``.
        """
        decoding_config = self.engine.get_decoding_config()
        return decoding_config
967

968
969
    async def get_scheduler_config(self) -> SchedulerConfig:
        """Return the scheduling configuration reported by the wrapped engine.

        Delegates to ``self.engine.get_scheduler_config()``.
        """
        scheduler_config = self.engine.get_scheduler_config()
        return scheduler_config
971
972
973

    async def get_lora_config(self) -> LoRAConfig:
        """Return the lora configuration reported by the wrapped engine.

        Delegates to ``self.engine.get_lora_config()``.
        """
        lora_config = self.engine.get_lora_config()
        return lora_config
975

976
977
978
979
    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Trigger stats logging on the wrapped engine.

        Args:
            scheduler_outputs: Accepted but not forwarded to the engine.
            model_output: Accepted but not forwarded to the engine.
        """
        wrapped_engine = self.engine
        wrapped_engine.do_log_stats()
981

982
    async def check_health(self) -> None:
        """Raise an error if the engine is unhealthy.

        Logs the start of the check, fails fast when the background loop is
        stopped, otherwise awaits the wrapped engine's own health check and
        logs how long the whole check took.

        Raises:
            AsyncEngineDeadError: If ``self.is_stopped`` is true.
        """
        started_at = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        await self.engine.check_health_async()
        elapsed_s = time.perf_counter() - started_at
        logger.debug("Health check took %fs", elapsed_s)
991
992

    async def is_tracing_enabled(self) -> bool:
        """Return whatever the wrapped engine's ``is_tracing_enabled()``
        reports."""
        tracing_enabled = self.engine.is_tracing_enabled()
        return tracing_enabled
994
995

    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
        """Register a stat logger on the wrapped engine.

        Args:
            logger_name: Name to register the logger under.
            logger: The stat logger instance.
        """
        register = self.engine.add_logger
        register(logger_name=logger_name, logger=logger)
997
998

    def remove_logger(self, logger_name: str) -> None:
        """Remove a stat logger from the wrapped engine.

        Args:
            logger_name: Name of the logger to remove.
        """
        unregister = self.engine.remove_logger
        unregister(logger_name=logger_name)
1000
1001

    async def start_profile(self) -> None:
        """Call ``start_profile()`` on the wrapped engine."""
        wrapped_engine = self.engine
        wrapped_engine.start_profile()
1003
1004

    async def stop_profile(self) -> None:
        """Call ``stop_profile()`` on the wrapped engine."""
        wrapped_engine = self.engine
        wrapped_engine.stop_profile()
1006

1007
1008
1009
    async def reset_mm_cache(self) -> None:
        """Call ``reset_mm_cache()`` on the wrapped engine."""
        wrapped_engine = self.engine
        wrapped_engine.reset_mm_cache()

1010
1011
1012
    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        """Call ``reset_prefix_cache`` on the wrapped engine.

        Args:
            device: Passed unchanged to the engine; ``None`` by default.
        """
        wrapped_engine = self.engine
        wrapped_engine.reset_prefix_cache(device)
1013

1014
    async def sleep(self, level: int = 1) -> None:
        """Reset the prefix cache, then put the wrapped engine to sleep.

        Args:
            level: Passed unchanged to ``self.engine.sleep``.
        """
        # The prefix cache is reset (with the default device argument)
        # before the engine is told to sleep.
        await self.reset_prefix_cache()
        wrapped_engine = self.engine
        wrapped_engine.sleep(level)

1018
1019
    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Wake the wrapped engine up.

        Args:
            tags: Passed unchanged to ``self.engine.wake_up``.
        """
        wrapped_engine = self.engine
        wrapped_engine.wake_up(tags)
1020

1021
1022
1023
    async def is_sleeping(self) -> bool:
        """Return whatever the wrapped engine's ``is_sleeping()`` reports."""
        sleeping = self.engine.is_sleeping()
        return sleeping

1024
1025
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add a LoRA adapter through the wrapped engine.

        Args:
            lora_request: Passed unchanged to ``self.engine.add_lora``.

        Returns:
            The value returned by ``self.engine.add_lora``.
        """
        added = self.engine.add_lora(lora_request)
        return added
1026

1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
    async def collective_rpc(self,
                             method: str,
                             timeout: Optional[float] = None,
                             args: tuple = (),
                             kwargs: Optional[dict] = None):
        """Perform a collective RPC call of the given method.

        Args:
            method: Forwarded to ``self.engine.collective_rpc_async``.
            timeout: Forwarded unchanged; ``None`` by default.
            args: Positional arguments for the call, forwarded as a tuple.
            kwargs: Keyword arguments for the call, forwarded as a dict
                (or ``None``).

        Returns:
            The awaited result of ``self.engine.collective_rpc_async``.
        """
        pending_call = self.engine.collective_rpc_async(
            method, timeout, args, kwargs)
        return await pending_call

1038
1039

# TODO(v1): Remove this class proxy when V1 goes default.
# When `envs` reports VLLM_USE_V1 as set and its value is truthy, the
# module-level name `AsyncLLMEngine` is rebound to the V1 `AsyncLLM` class.
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
    # Imported inside the branch, so the V1 module is only imported when
    # the override applies.
    from vllm.v1.engine.async_llm import AsyncLLM

    AsyncLLMEngine = AsyncLLM  # type: ignore