async_llm.py 24.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
from collections.abc import AsyncGenerator, Mapping
5
from copy import copy
6
from typing import Any, Optional, Union
7

8
9
import numpy as np

10
import vllm.envs as envs
11
12
13
from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
14
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
15
from vllm.inputs import PromptType
16
from vllm.inputs.preprocess import InputPreprocessor
17
18
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
19
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
20
from vllm.outputs import PoolingRequestOutput, RequestOutput
21
22
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
23
from vllm.sampling_params import SamplingParams
24
25
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
26
27
28
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
29
from vllm.utils import Device, cdiv
30
from vllm.v1.engine import EngineCoreRequest
31
from vllm.v1.engine.core_client import EngineCoreClient
32
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
33
34
from vllm.v1.engine.output_processor import (OutputProcessor,
                                             RequestOutputCollector)
35
from vllm.v1.engine.parallel_sampling import ParentRequest
36
from vllm.v1.engine.processor import Processor
37
from vllm.v1.executor.abstract import Executor
38
39
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
                                     setup_default_loggers)
40
from vllm.v1.metrics.prometheus import shutdown_prometheus
41
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
42
43
44
45
46
47
48
49
50

# Module-level logger used by AsyncLLM for request and lifecycle messages.
logger = init_logger(__name__)


class AsyncLLM(EngineClient):

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        client_addresses: Optional[dict[str, str]] = None,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats.
            usage_context: Usage context of the LLM. Accepted for interface
                compatibility; not read by this constructor.
            mm_registry: Multi-modal registry, forwarded to the Processor.
            use_cached_outputs: Accepted for interface compatibility; not
                read by this constructor.
            log_requests: Whether to log requests.
            start_engine_loop: Accepted for interface compatibility; not
                read by this constructor. The output handler is started
                eagerly when an asyncio event loop is already running, and
                otherwise lazily on the first generate()/encode() call.
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            client_addresses: Optional addresses forwarded to the EngineCore
                client.
            client_index: Index of this client, forwarded to the EngineCore
                client.

        Returns:
            None

        Raises:
            ValueError: if envs.VLLM_USE_V1 is falsy.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.log_requests = log_requests
        self.log_stats = log_stats

        # Set up stat loggers; independent set for each DP rank.
        self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
            vllm_config=vllm_config,
            log_stats=self.log_stats,
            engine_num=vllm_config.parallel_config.data_parallel_size,
            custom_stat_loggers=stat_loggers,
        )

        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            lora_config=vllm_config.lora_config)

        # Processor (converts Inputs --> EngineCoreRequests).
        self.processor = Processor(
            vllm_config=vllm_config,
            tokenizer=self.tokenizer,
            mm_registry=mm_registry,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(self.tokenizer,
                                                log_stats=self.log_stats)

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_index=client_index,
        )

        # Only the first engine's logger set announces initialization.
        if self.stat_loggers:
            for stat_logger in self.stat_loggers[0]:
                stat_logger.log_engine_initialized()

        self.output_handler: Optional[asyncio.Task] = None
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running event loop yet; generate()/encode() start the
            # output handler lazily on first use.
            pass
        else:
            # Start output handler eagerly if we are in the asyncio eventloop.
            # This call is deliberately outside the try so that a RuntimeError
            # raised while creating the task is surfaced, not swallowed.
            self._run_output_handler()

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        disable_log_requests: bool = False,
        disable_log_stats: bool = False,
        client_addresses: Optional[dict[str, str]] = None,
        client_index: int = 0,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from an already-built VllmConfig.

        The Executor class is derived from ``vllm_config``, and the
        ``disable_log_*`` flags are inverted into the constructor's
        ``log_requests`` / ``log_stats`` arguments.

        Raises:
            ValueError: if envs.VLLM_USE_V1 is falsy.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        executor_cls = Executor.get_class(vllm_config)
        should_log_requests = not disable_log_requests
        should_log_stats = not disable_log_stats

        # Create the LLMEngine.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_cls,
            log_stats=should_log_stats,
            log_requests=should_log_requests,
            usage_context=usage_context,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            client_addresses=client_addresses,
            client_index=client_index,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs.

        The engine config is built from ``engine_args`` for the given
        ``usage_context``; the Executor class is chosen from that config,
        and the logging switches are the negation of the ``disable_*``
        fields on ``engine_args``.
        """
        config = engine_args.create_engine_config(usage_context)
        constructor_kwargs = dict(
            vllm_config=config,
            executor_class=Executor.get_class(config),
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )
        return cls(**constructor_kwargs)

    def __del__(self):
        """Release engine resources when this object is garbage collected."""
        self.shutdown()

    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC.

        Also stops Prometheus and cancels the output handler task. Safe on a
        partially constructed instance: attributes that were never assigned
        (or are falsy) are skipped.
        """
        shutdown_prometheus()

        core_client = getattr(self, "engine_core", None)
        if core_client:
            core_client.shutdown()

        handler_task = getattr(self, "output_handler", None)
        if handler_task:
            handler_task.cancel()

    async def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        The prompt is converted into an EngineCoreRequest and registered with
        both the OutputProcessor and the EngineCore. Sampling requests with
        ``n > 1`` are fanned out into ``n`` child requests that share one
        ParentRequest and one output collector.

        Returns:
            The RequestOutputCollector that receives this request's outputs.

        Raises:
            EngineDeadError: if the engine has already errored.
        """
        if self.errored:
            raise EngineDeadError()

        is_pooling = isinstance(params, PoolingParams)

        # Create a new output collector for the request.
        queue = RequestOutputCollector(output_kind=params.output_kind)

        # Convert Input --> Request.
        prompt_str, request = self.processor.process_inputs(
            request_id, prompt, params, arrival_time, lora_request,
            tokenization_kwargs, trace_headers, prompt_adapter_request,
            priority, data_parallel_rank)

        if is_pooling or params.n == 1:
            await self._add_request(request, prompt_str, None, 0, queue)
            return queue

        # Fan out child requests (for n>1).
        # Capture n before the loop: the per-child params returned by
        # get_child_info() must not shadow the parent's params, otherwise the
        # "last child" check below would compare against the child's n.
        num_children = params.n
        parent_request = ParentRequest(request_id, params)
        for idx in range(num_children):
            child_id, child_params = parent_request.get_child_info(idx)
            # Earlier children get shallow copies of the untouched original;
            # the last child reuses the original object, saving one copy.
            is_last_child = idx == num_children - 1
            child_request = request if is_last_child else copy(request)
            child_request.request_id = child_id
            child_request.sampling_params = child_params
            await self._add_request(child_request, prompt_str, parent_request,
                                    idx, queue)
        return queue

    async def _add_request(self, request: EngineCoreRequest,
                           prompt: Optional[str],
                           parent_req: Optional[ParentRequest], index: int,
                           queue: RequestOutputCollector):
        """Register a single EngineCoreRequest locally and with EngineCore.

        The OutputProcessor (this process) learns about the request first,
        then the request is sent to the EngineCore (separate process), and
        finally the addition is logged if request logging is enabled.
        """
        local_processor = self.output_processor
        local_processor.add_request(request, prompt, parent_req, index,
                                    queue)

        core_client = self.engine_core
        await core_client.add_request_async(request)

        if not self.log_requests:
            return
        logger.info("Added request %s.", request.request_id)

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Making an output collector corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the OutputProcessor.
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background asyncio task,
        pulling outputs from EngineCore and putting them into the
        per-request collector.

        The caller of generate() iterates the returned AsyncGenerator,
        receiving each RequestOutput until the request is finished.

        Raises:
            EngineDeadError: if the engine is dead (the request is not
                aborted, since the engine has shut down).
            ValueError: on a request validation error.
            EngineGenerateError: on any other unexpected error (the request
                is aborted first).
        """
        try:
            # The output_handler is started here rather than in __init__ so
            # that __init__ can run before the event loop exists, which lets
            # the OpenAI server handle startup failure gracefully.
            self._run_output_handler()

            collector = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
            )

            # The output_handler task fills the collector; this task drains
            # it and hands each item to the caller.
            while True:
                # Non-blocking read first: skipping the await when output is
                # already there avoids task switching under load.
                request_output = collector.get_nowait() or await collector.get()

                # OutputProcessor and EngineCore clean up finished requests
                # themselves; this loop only needs to stop.
                is_done = request_output.finished
                yield request_output
                if is_done:
                    break

        except (asyncio.CancelledError, GeneratorExit):
            # The client disconnected: generate() was cancelled or the
            # generator was closed / garbage collected. Abort the request.
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        except EngineDeadError:
            # Engine is dead. Do not abort since we shut down.
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        except ValueError:
            # Request validation error.
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        except Exception as e:
            # Unexpected error in the generate() task (possibly recoverable).
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    def _run_output_handler(self):
        """Start the background loop that pulls from EngineCore and pushes
        results into the per-request output collectors.

        At most one handler task is created per AsyncLLM; later calls return
        immediately. The task repeatedly:
          1) awaits the next batch of EngineCoreOutputs from the EngineCore,
          2) feeds it through the OutputProcessor in chunks of at most
             VLLM_V1_OUTPUT_PROC_CHUNK_SIZE outputs, yielding to the event
             loop between chunks,
          3) aborts the requests listed in ``reqs_to_abort``,
          4) records stats with the logger set of the originating engine.

        Uses asyncio.create_task, so it must be called while an asyncio
        event loop is running. If the loop raises an Exception, the error is
        logged and handed to OutputProcessor.propagate_error(), and the task
        then finishes (so ``is_running`` becomes False).
        """

        # Idempotent: only one handler task is ever created.
        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        # None when stats are disabled, which skips step 4 below.
        stat_loggers = self.stat_loggers if log_stats else None

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    num_outputs = len(outputs.outputs)

                    # Only allocate IterationStats when stats are enabled and
                    # the batch actually contains outputs.
                    iteration_stats = IterationStats() if (
                        log_stats and num_outputs) else None

                    # Split outputs into chunks of at most
                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                    # event loop for too long.
                    if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
                        slices = (outputs.outputs, )
                    else:
                        # np.array_split into cdiv(...) sections gives chunks
                        # whose sizes differ by at most one.
                        slices = np.array_split(
                            outputs.outputs,
                            cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))

                    for i, outputs_slice in enumerate(slices):
                        # 2) Process EngineCoreOutputs.
                        processed_outputs = output_processor.process_outputs(
                            outputs_slice, outputs.timestamp, iteration_stats)
                        # NOTE: RequestOutputs are pushed to their queues.
                        # so nothing should be returned to this loop directly.
                        assert not processed_outputs.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if i + 1 < len(slices):
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        await engine_core.abort_requests_async(
                            processed_outputs.reqs_to_abort)

                    # 4) Logging.
                    # outputs.engine_index selects the logger set of the
                    # engine this batch came from.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if stat_loggers:
                        AsyncLLM._record_stats(
                            stat_loggers[outputs.engine_index],
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                        )
            # Only Exception is caught: asyncio.CancelledError (raised when
            # shutdown() cancels the task) is a BaseException and propagates.
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")
                # NOTE(review): presumably this wakes pending generate()/
                # encode() consumers with the error — confirm in
                # OutputProcessor.propagate_error.
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())
423
424

    async def abort(self, request_id: str) -> None:
        """Abort RequestId in OutputProcessor and EngineCore.

        The ids returned by OutputProcessor.abort_requests() for
        ``request_id`` are the ones forwarded to the EngineCore.
        """
        ids_to_abort = self.output_processor.abort_requests((request_id, ))
        await self.engine_core.abort_requests_async(ids_to_abort)

        if not self.log_requests:
            return
        logger.info("Aborted request %s.", request_id)

    @staticmethod
    def _record_stats(
        stat_loggers: list[StatLoggerBase],
        scheduler_stats: Optional[SchedulerStats],
        iteration_stats: Optional[IterationStats],
    ):
        """Forward one step's stats to every logger in ``stat_loggers``.

        Static so that it can be used from the output_handler task
        without a circular ref to AsyncLLM.
        """
        record_kwargs = dict(scheduler_stats=scheduler_stats,
                             iteration_stats=iteration_stats)
        for sink in stat_loggers:
            sink.record(**record_kwargs)

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Main function called by the API server to kick off a pooling request
            * 1) Making an output collector corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background asyncio task,
        pulling outputs from EngineCore and putting them into the
        per-request collector.

        The caller of encode() iterates the returned AsyncGenerator,
        receiving each PoolingRequestOutput until the request is finished.

        Args:
            tokenization_kwargs: Optional extra tokenization arguments,
                forwarded to add_request(). Defaults to None, matching the
                previous behavior.

        Raises:
            EngineDeadError: if the engine is dead (the request is not
                aborted, since the engine has shut down).
            ValueError: on a request validation error.
            EngineGenerateError: on any other unexpected error (the request
                is aborted first).
        """

        try:
            # We start the output_handler on the first call to encode() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()

            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, encode() is
        # cancelled or the generator is closed / garbage collected. Both
        # cases must abort the request; GeneratorExit is a BaseException and
        # would otherwise bypass every handler below and leak the request.
        except (asyncio.CancelledError, GeneratorExit):
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    async def get_vllm_config(self) -> VllmConfig:
        """Return the global VllmConfig this AsyncLLM was created with."""
        config = self.vllm_config
        return config

    async def get_model_config(self) -> ModelConfig:
        """Return the ModelConfig taken from the VllmConfig at construction."""
        model_cfg = self.model_config
        return model_cfg

    async def get_decoding_config(self):
        """Not available in V1.

        Raises:
            ValueError: always.
        """
        unsupported_msg = "Not Supported on V1 yet."
        raise ValueError(unsupported_msg)

    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Return the InputPreprocessor owned by this engine's Processor."""
        processor = self.processor
        return processor.input_preprocessor

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer for ``lora_request``.

        Resolution is delegated to the tokenizer group's
        get_lora_tokenizer().
        """
        tokenizer_group = self.tokenizer
        return tokenizer_group.get_lora_tokenizer(lora_request)

    async def is_tracing_enabled(self) -> bool:
        """Tracing is not supported by this engine; always returns False."""
        tracing_supported = False
        return tracing_supported

    async def do_log_stats(
        self,
        scheduler_outputs=None,
        model_output=None,
    ) -> None:
        """Call ``log()`` on every stat logger of every engine.

        ``scheduler_outputs`` and ``model_output`` are not used here; they
        are accepted for interface compatibility.
        """
        every_logger = (stat_logger for per_engine in self.stat_loggers
                        for stat_logger in per_engine)
        for stat_logger in every_logger:
            stat_logger.log()

    async def check_health(self) -> None:
        """Raise ``dead_error`` if the engine has errored; otherwise no-op."""
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error

    async def start_profile(self) -> None:
        """Ask the EngineCore to start profiling."""
        enable_profiling = True
        await self.engine_core.profile_async(enable_profiling)

    async def stop_profile(self) -> None:
        """Ask the EngineCore to stop profiling."""
        enable_profiling = False
        await self.engine_core.profile_async(enable_profiling)

    async def reset_mm_cache(self) -> None:
        """Reset multi-modal caches, locally and in the EngineCore.

        Order: the registry's processor cache, then the processor's input
        cache client, then the EngineCore's cache.
        """
        processor = self.processor
        processor.mm_registry.reset_processor_cache()
        processor.mm_input_cache_client.reset()
        await self.engine_core.reset_mm_cache_async()

    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        """Reset the EngineCore's prefix cache.

        Raises:
            ValueError: if ``device`` is Device.CPU.
        """
        cpu_requested = device == Device.CPU
        if cpu_requested:
            raise ValueError("Not supported on CPU.")
        await self.engine_core.reset_prefix_cache_async()

    async def sleep(self, level: int = 1) -> None:
        """Forward a sleep request with the given ``level`` to EngineCore."""
        core_client = self.engine_core
        await core_client.sleep_async(level)

    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Forward a wake-up request with the given ``tags`` to EngineCore."""
        core_client = self.engine_core
        await core_client.wake_up_async(tags)

    async def is_sleeping(self) -> bool:
        """Return the sleeping state reported by the EngineCore."""
        sleeping = await self.engine_core.is_sleeping_async()
        return sleeping

    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests.

        Delegates to the EngineCore and returns its result.
        """
        added = await self.engine_core.add_lora_async(lora_request)
        return added

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter.

        Delegates to the EngineCore and returns its result.
        """
        removed = await self.engine_core.remove_lora_async(lora_id)
        return removed

    async def list_loras(self) -> set[int]:
        """List all registered adapters, as reported by the EngineCore."""
        lora_ids = await self.engine_core.list_loras_async()
        return lora_ids

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted.

        Delegates to the EngineCore and returns its result.
        """
        pinned = await self.engine_core.pin_lora_async(lora_id)
        return pinned

    async def collective_rpc(self,
                             method: str,
                             timeout: Optional[float] = None,
                             args: tuple = (),
                             kwargs: Optional[dict] = None):
        """
        Perform a collective RPC call to the given path.

        All arguments are forwarded unchanged to the EngineCore client, and
        its result is returned.
        """
        core_client = self.engine_core
        rpc_result = await core_client.collective_rpc_async(
            method, timeout, args, kwargs)
        return rpc_result

    @property
    def is_running(self) -> bool:
        """True until the output handler task has finished.

        Before the handler is started (``output_handler`` is None) the
        engine is considered running.
        """
        handler_task = self.output_handler
        if handler_task is None:
            return True
        return not handler_task.done()

    @property
    def is_stopped(self) -> bool:
        """Same as ``errored``."""
        stopped = self.errored
        return stopped

    @property
    def errored(self) -> bool:
        """True if the EngineCore reports its engine dead, or if the output
        handler task is no longer running."""
        engine_dead = self.engine_core.resources.engine_dead
        return engine_dead or not self.is_running

    @property
    def dead_error(self) -> BaseException:
        """A fresh EngineDeadError instance (raised by check_health())."""
        error = EngineDeadError()
        return error