async_llm.py 21.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
import asyncio
3
from collections.abc import AsyncGenerator, Mapping
4
from copy import copy
5
from typing import Any, Optional, Union
6

7
8
import numpy as np

9
import vllm.envs as envs
10
11
12
from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
13
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
14
from vllm.inputs import PromptType
15
from vllm.inputs.preprocess import InputPreprocessor
16
17
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
18
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
19
from vllm.outputs import RequestOutput
20
21
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
22
from vllm.sampling_params import SamplingParams
23
24
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
25
26
27
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
28
from vllm.utils import Device, cdiv
29
from vllm.v1.engine import EngineCoreRequest
30
31
from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
32
33
from vllm.v1.engine.output_processor import (OutputProcessor,
                                             RequestOutputCollector)
34
from vllm.v1.engine.parallel_sampling import ParentRequest
35
from vllm.v1.engine.processor import Processor
36
from vllm.v1.executor.abstract import Executor
37
38
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
                                     setup_default_loggers)
39
from vllm.v1.metrics.prometheus import shutdown_prometheus
40
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
41
42
43
44
45
46
47
48
49

# Module-level logger for this file, created via vLLM's init_logger helper.
logger = init_logger(__name__)


class AsyncLLM(EngineClient):
    """V1 asyncio engine client.

    Combines three pieces:
      * ``Processor`` - converts prompts/params into ``EngineCoreRequest``s
        in this process,
      * an EngineCore client (``AsyncMPClient`` or ``DPAsyncMPClient``) that
        starts and talks to the engine in a background process,
      * ``OutputProcessor`` - converts ``EngineCoreOutputs`` into
        ``RequestOutput``s and pushes them into per-request collectors.

    A background asyncio task (created by ``_run_output_handler``) pulls
    outputs from the EngineCore and feeds the per-request collectors that
    ``generate()`` drains.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        client_addresses: Optional[dict[str, str]] = None,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats.
            usage_context: Usage context of the LLM. Accepted for interface
                compatibility; not read by this constructor.
            mm_registry: Multi-modal registry.
            use_cached_outputs: Whether to use cached outputs. Accepted for
                interface compatibility; not read by this constructor.
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop. Accepted for
                interface compatibility; not read by this constructor. The
                output handler is started here only when a running event
                loop exists, otherwise on the first ``generate()`` call.
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            client_addresses: Optional addresses, forwarded unchanged to the
                EngineCore client constructor.
            client_index: Forwarded unchanged to the EngineCore client
                constructor.

        Returns:
            None

        Raises:
            ValueError: if ``envs.VLLM_USE_V1`` is false.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.log_requests = log_requests
        self.log_stats = log_stats

        # Set up stat loggers; independent set for each DP rank.
        self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
            vllm_config=vllm_config,
            log_stats=self.log_stats,
            engine_num=vllm_config.parallel_config.data_parallel_size,
            custom_stat_loggers=stat_loggers,
        )

        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            lora_config=vllm_config.lora_config)

        # Processor (converts Inputs --> EngineCoreRequests).
        self.processor = Processor(
            vllm_config=vllm_config,
            tokenizer=self.tokenizer,
            mm_registry=mm_registry,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(self.tokenizer,
                                                log_stats=self.log_stats)

        # EngineCore (starts the engine in background process).
        # A data-parallel deployment needs the DP-aware client.
        core_client_class = AsyncMPClient if (
            vllm_config.parallel_config.data_parallel_size
            == 1) else DPAsyncMPClient

        self.engine_core = core_client_class(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_index=client_index,
        )

        if self.stat_loggers:
            for stat_logger in self.stat_loggers[0]:
                stat_logger.log_engine_initialized()

        self.output_handler: Optional[asyncio.Task] = None

        # Start the output handler eagerly only if we are already inside an
        # asyncio event loop. Only the loop probe is guarded: a RuntimeError
        # raised while actually starting the handler must not be swallowed.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass
        else:
            self._run_output_handler()

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        disable_log_requests: bool = False,
        disable_log_stats: bool = False,
        client_addresses: Optional[dict[str, str]] = None,
        client_index: int = 0,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from an already-built ``VllmConfig``.

        The executor class is resolved with ``Executor.get_class``. The
        ``disable_*`` flags are inverted into ``log_requests``/``log_stats``.

        Raises:
            ValueError: if ``envs.VLLM_USE_V1`` is false. This is checked
                before the executor class is resolved.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        # Create the LLMEngine.
        return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            log_requests=not disable_log_requests,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
            client_addresses=client_addresses,
            client_index=client_index,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs."""

        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    def __del__(self):
        self.shutdown()

    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC."""

        shutdown_prometheus()

        # getattr() because shutdown() may run from __del__ on a partially
        # constructed instance (e.g. if __init__ raised early).
        if engine_core := getattr(self, "engine_core", None):
            engine_core.shutdown()

        if handler := getattr(self, "output_handler", None):
            handler.cancel()

    async def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        Returns the ``RequestOutputCollector`` that the output handler fills
        with outputs for this request. When ``params.n > 1``, the request is
        fanned out into ``n`` child requests that share one collector.

        Raises:
            EngineDeadError: if the engine has already errored.
        """

        if self.errored:
            raise EngineDeadError()

        assert isinstance(params, SamplingParams), \
            "Pooling is not supported in V1"

        # Create a new output collector for the request.
        queue = RequestOutputCollector(output_kind=params.output_kind)

        # Convert Input --> Request.
        prompt_str, request = self.processor.process_inputs(
            request_id, prompt, params, arrival_time, lora_request,
            tokenization_kwargs, trace_headers, prompt_adapter_request,
            priority)

        if params.n == 1:
            await self._add_request(request, prompt_str, None, 0, queue)
            return queue

        # Fan out child requests (for n>1). The parent's params/id are kept
        # in their own names. Rebinding them to the child values returned by
        # get_child_info() would make the "last child" check below compare
        # against the child's ``n`` instead of the parent's.
        parent_params = params
        num_children = parent_params.n
        parent_request = ParentRequest(request_id, parent_params)
        for idx in range(num_children):
            child_id, child_params = parent_request.get_child_info(idx)
            # Reuse the original request object for the last child only.
            # Every earlier child gets a shallow copy taken before the
            # original is mutated.
            child_request = (request if idx == num_children - 1 else
                             copy(request))
            child_request.request_id = child_id
            child_request.sampling_params = child_params
            await self._add_request(child_request, prompt_str,
                                    parent_request, idx, queue)
        return queue

    async def _add_request(self, request: EngineCoreRequest,
                           prompt: Optional[str],
                           parent_req: Optional[ParentRequest], index: int,
                           queue: RequestOutputCollector):
        """Register ``request`` with the OutputProcessor, then send it to
        the EngineCore."""

        # Add the request to OutputProcessor (this process).
        self.output_processor.add_request(request, prompt, parent_req, index,
                                          queue)

        # Add the EngineCoreRequest to EngineCore (separate process).
        await self.engine_core.add_request_async(request)

        if self.log_requests:
            logger.info("Added request %s.", request.request_id)

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Making an AsyncStream corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the Detokenizer.
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request AsyncStream.

        The caller of generate() iterates the returned AsyncGenerator,
        returning the RequestOutput back to the caller.

        Raises:
            asyncio.CancelledError: re-raised after aborting the request.
            EngineDeadError: re-raised without aborting.
            ValueError: re-raised (request validation error).
            EngineGenerateError: wraps any other exception, after aborting
                the request.
        """

        try:
            # We start the output_handler on the first call to generate() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()

            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, generate()
        # is cancelled. So, we abort the request if we end up here.
        except asyncio.CancelledError:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    def _run_output_handler(self):
        """Background loop: pulls from EngineCore and pushes to AsyncStreams.

        Idempotent: does nothing if the handler task was already created.
        Must be called while an asyncio event loop is running, because it
        uses ``asyncio.create_task``.
        """

        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        stat_loggers = self.stat_loggers if log_stats else None

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    num_outputs = len(outputs.outputs)

                    iteration_stats = IterationStats() if (
                        log_stats and num_outputs) else None

                    # Split outputs into chunks of at most
                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                    # event loop for too long.
                    if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
                        slices = (outputs.outputs, )
                    else:
                        slices = np.array_split(
                            outputs.outputs,
                            cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))

                    for i, outputs_slice in enumerate(slices):
                        # 2) Process EngineCoreOutputs.
                        processed_outputs = output_processor.process_outputs(
                            outputs_slice, outputs.timestamp, iteration_stats)
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed_outputs.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if i + 1 < len(slices):
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        await engine_core.abort_requests_async(
                            processed_outputs.reqs_to_abort)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if stat_loggers:
                        AsyncLLM._record_stats(
                            stat_loggers[outputs.engine_index],
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                        )
            except Exception as e:
                # Surface the failure to every waiting request instead of
                # letting them hang on their collectors.
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())

    async def abort(self, request_id: str) -> None:
        """Abort RequestId in OutputProcessor and EngineCore."""

        # The OutputProcessor returns the ids to forward to the EngineCore.
        request_ids = self.output_processor.abort_requests((request_id, ))
        await self.engine_core.abort_requests_async(request_ids)

        if self.log_requests:
            logger.info("Aborted request %s.", request_id)

    @staticmethod
    def _record_stats(
        stat_loggers: list[StatLoggerBase],
        scheduler_stats: Optional[SchedulerStats],
        iteration_stats: Optional[IterationStats],
    ):
        """static so that it can be used from the output_handler task
        without a circular ref to AsyncLLM."""
        for stat_logger in stat_loggers:
            stat_logger.record(scheduler_stats=scheduler_stats,
                               iteration_stats=iteration_stats)

    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ):
        """Pooling/encode is not implemented for V1; always raises."""
        raise ValueError("Not Supported on V1 yet.")

    async def get_vllm_config(self) -> VllmConfig:
        return self.vllm_config

    async def get_model_config(self) -> ModelConfig:
        return self.model_config

    async def get_decoding_config(self):
        raise ValueError("Not Supported on V1 yet.")

    async def get_input_preprocessor(self) -> InputPreprocessor:
        return self.processor.input_preprocessor

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        return self.tokenizer.get_lora_tokenizer(lora_request)

    async def is_tracing_enabled(self) -> bool:
        return False

    async def do_log_stats(
        self,
        scheduler_outputs=None,
        model_output=None,
    ) -> None:
        """Flush every stat logger of every DP rank. The arguments are
        accepted for interface compatibility and are not used."""
        for loggers in self.stat_loggers:
            for stat_logger in loggers:
                stat_logger.log()

    async def check_health(self) -> None:
        logger.debug("Called check_health.")

    async def start_profile(self) -> None:
        await self.engine_core.profile_async(True)

    async def stop_profile(self) -> None:
        await self.engine_core.profile_async(False)

    async def reset_mm_cache(self) -> None:
        """Reset multi-modal caches in this process and in the EngineCore."""
        self.processor.mm_registry.reset_processor_cache()
        self.processor.mm_input_cache_client.reset()
        await self.engine_core.reset_mm_cache_async()

    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        if device == Device.CPU:
            raise ValueError("Not supported on CPU.")
        await self.engine_core.reset_prefix_cache_async()

    async def sleep(self, level: int = 1) -> None:
        await self.engine_core.sleep_async(level)

    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        await self.engine_core.wake_up_async(tags)

    async def is_sleeping(self) -> bool:
        return await self.engine_core.is_sleeping_async()

    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return await self.engine_core.add_lora_async(lora_request)

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return await self.engine_core.remove_lora_async(lora_id)

    async def list_loras(self) -> set[int]:
        """List all registered adapters."""
        return await self.engine_core.list_loras_async()

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return await self.engine_core.pin_lora_async(lora_id)

    async def collective_rpc(self,
                             method: str,
                             timeout: Optional[float] = None,
                             args: tuple = (),
                             kwargs: Optional[dict] = None):
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
            method, timeout, args, kwargs)

    @property
    def is_running(self) -> bool:
        # Is None before the loop is started.
        return self.output_handler is None or not self.output_handler.done()

    @property
    def is_stopped(self) -> bool:
        return self.errored

    @property
    def errored(self) -> bool:
        return self.engine_core.resources.engine_dead or not self.is_running

    @property
    def dead_error(self) -> BaseException:
        return EngineDeadError()