# async_llm.py
# SPDX-License-Identifier: Apache-2.0
2
import asyncio
3
from collections.abc import AsyncGenerator, Mapping
4
from copy import copy
5
from typing import Any, Optional, Union
6

7
8
import numpy as np

9
import vllm.envs as envs
10
11
12
from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
13
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
14
from vllm.inputs import PromptType
15
from vllm.inputs.preprocess import InputPreprocessor
16
17
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
18
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
19
from vllm.outputs import RequestOutput
20
21
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
22
from vllm.sampling_params import SamplingParams
23
24
25
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
26
from vllm.utils import Device, cdiv
27
from vllm.v1.engine import EngineCoreRequest
28
29
from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
30
31
from vllm.v1.engine.output_processor import (OutputProcessor,
                                             RequestOutputCollector)
32
from vllm.v1.engine.parallel_sampling import ParentRequest
33
from vllm.v1.engine.processor import Processor
34
from vllm.v1.executor.abstract import Executor
35
36
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
                                     setup_default_loggers)
37
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
38
# Module-level logger, named after this module via __name__.
logger = init_logger(__name__)


class AsyncLLM(EngineClient):
    """V1 asyncio engine client.

    Converts prompts into ``EngineCoreRequest``s with a ``Processor`` and
    submits them to an EngineCore through an async multiprocess client
    (``AsyncMPClient``, or ``DPAsyncMPClient`` when data parallel size > 1).
    A background ``output_handler`` task pulls EngineCore outputs, runs them
    through the ``OutputProcessor`` and pushes results into per-request
    ``RequestOutputCollector`` queues, which ``generate()`` drains.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats.
            usage_context: Usage context of the LLM. Accepted for interface
                compatibility; not used by this constructor.
            mm_registry: Multi-modal registry.
            use_cached_outputs: Accepted for interface compatibility; not
                used by this constructor.
            log_requests: Whether to log requests.
            start_engine_loop: Accepted for interface compatibility; not
                used by this constructor. The output handler is started
                eagerly if an asyncio event loop is already running,
                otherwise on the first call to generate().
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.

        Raises:
            ValueError: if envs.VLLM_USE_V1 is not set.

        Returns:
            None
        """
        self._check_use_v1()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.log_requests = log_requests
        self.log_stats = log_stats

        # Set up stat loggers; independent set for each DP rank.
        self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
            vllm_config=vllm_config,
            log_stats=self.log_stats,
            engine_num=vllm_config.parallel_config.data_parallel_size,
            custom_stat_loggers=stat_loggers,
        )

        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            lora_config=vllm_config.lora_config)

        # Processor (converts Inputs --> EngineCoreRequests).
        self.processor = Processor(
            vllm_config=vllm_config,
            tokenizer=self.tokenizer,
            mm_registry=mm_registry,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(self.tokenizer,
                                                log_stats=self.log_stats)

        # EngineCore (starts the engine in background process).
        # A data-parallel-aware client is only needed when DP size > 1.
        if vllm_config.parallel_config.data_parallel_size == 1:
            core_client_class = AsyncMPClient
        else:
            core_client_class = DPAsyncMPClient

        self.engine_core = core_client_class(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
        )

        # Only the first (DP rank 0) logger set reports engine startup.
        if self.stat_loggers:
            for stat_logger in self.stat_loggers[0]:
                stat_logger.log_engine_initialized()

        self.output_handler: Optional[asyncio.Task] = None
        try:
            # Start output handler eagerly if we are in the asyncio eventloop.
            asyncio.get_running_loop()
            self._run_output_handler()
        except RuntimeError:
            # No running loop yet: generate() starts the handler lazily.
            pass

    @staticmethod
    def _check_use_v1() -> None:
        """Raise ValueError unless envs.VLLM_USE_V1 is enabled."""
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        disable_log_requests: bool = False,
        disable_log_stats: bool = False,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from an already-built VllmConfig.

        The executor class is selected with ``Executor.get_class``; the
        ``disable_*`` flags are inverted into ``log_requests``/``log_stats``.

        Raises:
            ValueError: if envs.VLLM_USE_V1 is not set.
        """
        cls._check_use_v1()

        # Create the LLMEngine.
        return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            log_requests=not disable_log_requests,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs."""

        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    def __del__(self):
        self.shutdown()

    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC.

        Uses getattr so it is safe to call on a partially-initialized
        instance (e.g. from __del__ after __init__ raised).
        """
        if engine_core := getattr(self, "engine_core", None):
            engine_core.shutdown()

        if handler := getattr(self, "output_handler", None):
            handler.cancel()

    async def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        For ``params.n > 1`` the request is fanned out into ``n`` child
        requests under one ParentRequest, all feeding the same queue.

        Returns:
            The RequestOutputCollector that outputs for this request are
            pushed into.

        Raises:
            EngineDeadError: if the engine has errored.
        """

        if self.errored:
            raise EngineDeadError()

        assert isinstance(params, SamplingParams), \
            "Pooling is not supported in V1"

        # Create a new output collector for the request.
        queue = RequestOutputCollector(output_kind=params.output_kind)

        # Convert Input --> Request.
        prompt_str, request = self.processor.process_inputs(
            request_id, prompt, params, arrival_time, lora_request,
            tokenization_kwargs, trace_headers, prompt_adapter_request,
            priority)

        if params.n == 1:
            await self._add_request(request, prompt_str, None, 0, queue)
            return queue

        # Fan out child requests (for n>1).
        parent_request = ParentRequest(request_id, params)
        # Capture n up front: the per-child params returned by
        # get_child_info must not drive the "last child" check below.
        num_children = params.n
        for idx in range(num_children):
            child_id, child_params = parent_request.get_child_info(idx)
            # Shallow-copy the original request for every child except the
            # last, which reuses the original object (saves one copy). The
            # original is therefore only mutated after all copies are made.
            if idx == num_children - 1:
                child_request = request
            else:
                child_request = copy(request)
            child_request.request_id = child_id
            child_request.sampling_params = child_params
            await self._add_request(child_request, prompt_str, parent_request,
                                    idx, queue)
        return queue

    async def _add_request(self, request: EngineCoreRequest,
                           prompt: Optional[str],
                           parent_req: Optional[ParentRequest], index: int,
                           queue: RequestOutputCollector):
        """Register one (possibly child) request locally and in EngineCore."""

        # Add the request to OutputProcessor (this process).
        self.output_processor.add_request(request, prompt, parent_req, index,
                                          queue)

        # Add the EngineCoreRequest to EngineCore (separate process).
        await self.engine_core.add_request_async(request)

        if self.log_requests:
            logger.info("Added request %s.", request.request_id)

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Starting the background output handler if needed.
            * 2) Processing the Input into an EngineCoreRequest.
            * 3) Adding the Request to the OutputProcessor (this process).
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request RequestOutputCollector.

        The caller of generate() iterates the returned AsyncGenerator,
        receiving RequestOutputs until one is marked finished.

        Raises:
            asyncio.CancelledError: re-raised after aborting the request.
            EngineDeadError: re-raised; the request is not aborted.
            ValueError: re-raised (request validation error).
            EngineGenerateError: wraps any other exception, after abort.
        """

        try:
            # We start the output_handler on the first call to generate() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()

            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, generate()
        # is cancelled. So, we abort the request if we end up here.
        except asyncio.CancelledError:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    def _run_output_handler(self):
        """Start the background task that pulls EngineCoreOutputs and
        pushes RequestOutputs into per-request queues.

        No-op if the task has already been created.
        """

        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        stat_loggers = self.stat_loggers if log_stats else None

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    num_outputs = len(outputs.outputs)

                    iteration_stats = IterationStats() if (
                        log_stats and num_outputs) else None

                    # Split outputs into chunks of at most
                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                    # event loop for too long.
                    if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
                        slices = (outputs.outputs, )
                    else:
                        slices = np.array_split(
                            outputs.outputs,
                            cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))

                    for i, outputs_slice in enumerate(slices):
                        # 2) Process EngineCoreOutputs.
                        processed_outputs = output_processor.process_outputs(
                            outputs_slice, outputs.timestamp, iteration_stats)
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed_outputs.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if i + 1 < len(slices):
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        await engine_core.abort_requests_async(
                            processed_outputs.reqs_to_abort)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if stat_loggers:
                        assert outputs.scheduler_stats is not None
                        AsyncLLM._record_stats(
                            stat_loggers[outputs.engine_index],
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                        )
            except Exception as e:
                # Surface the failure to every waiting request instead of
                # letting them hang on their queues.
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())

    async def abort(self, request_id: str) -> None:
        """Abort RequestId in OutputProcessor and EngineCore."""

        # The OutputProcessor returns the ids to abort in EngineCore
        # (this may differ from the single id passed in).
        request_ids = self.output_processor.abort_requests((request_id, ))
        await self.engine_core.abort_requests_async(request_ids)

        if self.log_requests:
            logger.info("Aborted request %s.", request_id)

    @staticmethod
    def _record_stats(
        stat_loggers: list[StatLoggerBase],
        scheduler_stats: SchedulerStats,
        iteration_stats: Optional[IterationStats],
    ):
        """static so that it can be used from the output_handler task
        without a circular ref to AsyncLLM."""
        for stat_logger in stat_loggers:
            stat_logger.record(scheduler_stats=scheduler_stats,
                               iteration_stats=iteration_stats)

    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ):
        """Pooling is not implemented in V1; always raises ValueError."""
        raise ValueError("Not Supported on V1 yet.")

    async def get_vllm_config(self) -> VllmConfig:
        return self.vllm_config

    async def get_model_config(self) -> ModelConfig:
        return self.model_config

    async def get_decoding_config(self):
        """Not implemented in V1; always raises ValueError."""
        raise ValueError("Not Supported on V1 yet.")

    async def get_input_preprocessor(self) -> InputPreprocessor:
        return self.processor.input_preprocessor

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer for ``lora_request`` (base if None)."""
        return self.tokenizer.get_lora_tokenizer(lora_request)

    async def is_tracing_enabled(self) -> bool:
        return False

    async def do_log_stats(
        self,
        scheduler_outputs=None,
        model_output=None,
    ) -> None:
        """Flush every stat logger of every DP rank.

        The arguments are accepted for interface compatibility and ignored.
        """
        for loggers in self.stat_loggers:
            for stat_logger in loggers:
                stat_logger.log()

    async def check_health(self) -> None:
        logger.debug("Called check_health.")

    async def start_profile(self) -> None:
        await self.engine_core.profile_async(True)

    async def stop_profile(self) -> None:
        await self.engine_core.profile_async(False)

    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        """Reset the EngineCore prefix cache.

        Raises:
            ValueError: if ``device`` is ``Device.CPU``.
        """
        if device == Device.CPU:
            raise ValueError("Not supported on CPU.")
        await self.engine_core.reset_prefix_cache_async()

    async def sleep(self, level: int = 1) -> None:
        await self.engine_core.sleep_async(level)

    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        await self.engine_core.wake_up_async(tags)

    async def is_sleeping(self) -> bool:
        return await self.engine_core.is_sleeping_async()

    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return await self.engine_core.add_lora_async(lora_request)

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return await self.engine_core.remove_lora_async(lora_id)

    async def list_loras(self) -> set[int]:
        """List all registered adapters."""
        return await self.engine_core.list_loras_async()

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return await self.engine_core.pin_lora_async(lora_id)

    async def collective_rpc(self,
                             method: str,
                             timeout: Optional[float] = None,
                             args: tuple = (),
                             kwargs: Optional[dict] = None):
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
            method, timeout, args, kwargs)

    @property
    def is_running(self) -> bool:
        # Is None before the loop is started.
        return self.output_handler is None or not self.output_handler.done()

    @property
    def is_stopped(self) -> bool:
        return self.errored

    @property
    def errored(self) -> bool:
        # Errored if the EngineCore reported death or the output handler
        # task has finished (it only exits on failure or cancellation).
        return self.engine_core.resources.engine_dead or not self.is_running

    @property
    def dead_error(self) -> BaseException:
        return EngineDeadError()