# SPDX-License-Identifier: Apache-2.0
"""V1 ``AsyncLLM``: asyncio front-end that submits requests to an
EngineCore client and streams ``RequestOutput`` objects back."""

import asyncio
import logging
from collections.abc import AsyncGenerator, Mapping
from copy import copy
from typing import Optional, Union

import numpy as np

import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, cdiv
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.output_processor import (OutputProcessor,
                                             RequestOutputCollector)
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
                                     StatLoggerBase)
from vllm.v1.metrics.stats import IterationStats, SchedulerStats

logger = init_logger(__name__)


class AsyncLLM(EngineClient):

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
    ) -> None:
        """Build the input/output processors and the EngineCore client.

        Args:
            vllm_config: Full engine configuration.
            executor_class: Executor class forwarded to the EngineCore
                client.
            log_stats: Whether to create stat loggers and collect stats.
            usage_context: Accepted for interface compatibility; not used
                by this constructor.
            mm_registry: Multimodal registry handed to the ``Processor``.
            use_cached_outputs: Accepted for interface compatibility; not
                used by this constructor.
            log_requests: Whether to log per-request lifecycle events.
            start_engine_loop: Accepted for interface compatibility; not
                used by this constructor. The output handler is started
                here only if an asyncio event loop is already running.

        Raises:
            ValueError: If ``envs.VLLM_USE_V1`` is False.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.log_requests = log_requests
        self.log_stats = log_stats

        # Set up stat loggers; independent set for each DP rank.
        self.stat_loggers: list[list[StatLoggerBase]] = []
        if self.log_stats:
            for i in range(vllm_config.parallel_config.data_parallel_size):
                loggers: list[StatLoggerBase] = []
                # Only create the console logger if INFO output is enabled.
                if logger.isEnabledFor(logging.INFO):
                    loggers.append(LoggingStatLogger(engine_index=i))
                loggers.append(
                    PrometheusStatLogger(vllm_config, engine_index=i))
                self.stat_loggers.append(loggers)

        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            lora_config=vllm_config.lora_config)

        # Processor (converts Inputs --> EngineCoreRequests).
        self.processor = Processor(
            vllm_config=vllm_config,
            tokenizer=self.tokenizer,
            mm_registry=mm_registry,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(self.tokenizer,
                                                log_stats=self.log_stats)

        # EngineCore (starts the engine in background process).
        # With data parallelism (DP size > 1) the DP-aware client is used.
        if vllm_config.parallel_config.data_parallel_size == 1:
            core_client_class = AsyncMPClient
        else:
            core_client_class = DPAsyncMPClient

        self.engine_core = core_client_class(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
        )

        self.output_handler: Optional[asyncio.Task] = None
        # Start output handler eagerly if we are in the asyncio eventloop.
        # Only the loop probe is guarded, so a RuntimeError raised while
        # starting the handler itself is not silently swallowed.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop yet; the handler is started later.
            pass
        else:
            self._run_output_handler()

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
        disable_log_requests: bool = False,
        disable_log_stats: bool = False,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from an already-built ``VllmConfig``.

        Args:
            vllm_config: Full engine configuration.
            start_engine_loop: Forwarded to the constructor.
            usage_context: Forwarded to the constructor.
            stat_loggers: Must be None; custom stat loggers are not yet
                supported on V1.
            disable_log_requests: Disable per-request logging.
            disable_log_stats: Disable stats collection and logging.

        Returns:
            A new ``AsyncLLM`` instance.

        Raises:
            ValueError: If ``envs.VLLM_USE_V1`` is False, or if
                ``stat_loggers`` is not None.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        # FIXME(rob): refactor VllmConfig to include the StatLoggers
        # include StatLogger in the Oracle decision.
        if stat_loggers is not None:
            raise ValueError("Custom StatLoggers are not yet supported on V1. "
                             "Explicitly set VLLM_USE_V1=0 to disable V1.")

        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            start_engine_loop=start_engine_loop,
            log_requests=not disable_log_requests,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs.

        Args:
            engine_args: Arguments used to build the ``VllmConfig``; their
                ``disable_log_requests`` / ``disable_log_stats`` flags are
                also honored.
            start_engine_loop: Forwarded to the constructor.
            usage_context: Used to build the config and forwarded to the
                constructor.

        Returns:
            A new ``AsyncLLM`` instance.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
        )

    def __del__(self):
        # Best-effort cleanup when the object is garbage collected.
        self.shutdown()

    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC.

        Uses ``getattr`` so it is safe on a partially-constructed instance
        (e.g. when ``__init__`` raised before these attributes were set).
        """
        if engine_core := getattr(self, "engine_core", None):
            engine_core.shutdown()

        # Cancel the background output-handler task if it was started.
        if handler := getattr(self, "output_handler", None):
            handler.cancel()

    async def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        The input is converted into an ``EngineCoreRequest``, registered
        with the OutputProcessor and sent to the EngineCore. When
        ``params.n > 1`` it is fanned out into ``n`` child requests sharing
        one ``ParentRequest`` and one output collector.

        Args:
            request_id: Id of the (parent) request.
            prompt: The prompt to process.
            params: Sampling parameters (pooling is not supported in V1).
            arrival_time: Optional arrival timestamp.
            lora_request: Optional LoRA adapter to use.
            trace_headers: Optional tracing headers.
            prompt_adapter_request: Optional prompt adapter to use.
            priority: Request priority.

        Returns:
            The ``RequestOutputCollector`` that receives the outputs.

        Raises:
            EngineDeadError: If the engine has already errored.
            AssertionError: If ``params`` is not ``SamplingParams``.
        """
        if self.errored:
            raise EngineDeadError()

        assert isinstance(params, SamplingParams), \
            "Pooling is not supported in V1"

        # Create a new output collector for the request.
        queue = RequestOutputCollector(output_kind=params.output_kind)

        # Convert Input --> Request.
        request = self.processor.process_inputs(request_id, prompt, params,
                                                arrival_time, lora_request,
                                                trace_headers,
                                                prompt_adapter_request,
                                                priority)

        if params.n == 1:
            await self._add_request(request, None, 0, queue)
            return queue

        # Fan out child requests (for n>1).
        # Capture the parent's fan-out count up front: the per-child params
        # must not shadow ``params``, otherwise the "last child" test below
        # would be evaluated against the child's ``n`` instead.
        num_children = params.n
        parent_request = ParentRequest(request_id, params)
        for idx in range(num_children):
            child_id, child_params = parent_request.get_child_info(idx)
            # Reuse the original request object for the last child and
            # shallow-copy it for the others. Each child is a distinct
            # object; only request_id and sampling_params are reassigned,
            # so children share every other attribute with the original.
            if idx == num_children - 1:
                child_request = request
            else:
                child_request = copy(request)
            child_request.request_id = child_id
            child_request.sampling_params = child_params
            await self._add_request(child_request, parent_request, idx, queue)
        return queue

    async def _add_request(self, request: EngineCoreRequest,
                           parent_req: Optional[ParentRequest], index: int,
                           queue: RequestOutputCollector):
        """Register ``request`` locally, then forward it to the EngineCore.

        Args:
            request: The (possibly child) request to submit.
            parent_req: The fan-out parent for ``n > 1`` requests, or None.
            index: Child index within ``parent_req``.
            queue: Output collector that receives this request's outputs.
        """
        # Add the request to OutputProcessor (this process).
        self.output_processor.add_request(request, parent_req, index, queue)

        # Add the EngineCoreRequest to EngineCore (separate process).
        await self.engine_core.add_request_async(request)

        if self.log_requests:
            logger.info("Added request %s.", request.request_id)

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Making an output collector (queue) for the Request.
            * 2) Processing the Input into an EngineCoreRequest.
            * 3) Adding the Request to the OutputProcessor.
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and pushing them into the
        per-request output collector.

        The caller of generate() iterates the returned AsyncGenerator,
        returning the RequestOutput back to the caller.

        Raises:
            asyncio.CancelledError: Re-raised after aborting the request
                (e.g. when the client disconnects).
            EngineDeadError: Re-raised without aborting (engine is down).
            ValueError: Re-raised for request validation errors.
            EngineGenerateError: Wraps any other exception, raised after
                the request has been aborted.
        """
        try:
            # We start the output_handler on the first call to generate() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()

            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, generate()
        # is cancelled. So, we abort the request if we end up here.
        except asyncio.CancelledError:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    def _run_output_handler(self):
        """Start the background task that drains EngineCore outputs.

        The task loops forever: it pulls ``EngineCoreOutputs`` from the
        EngineCore client, runs them through the OutputProcessor (which
        pushes each ``RequestOutput`` into its request's output collector),
        aborts requests that finished due to stop strings, and records
        stats. If anything raises, the error is logged and propagated to
        the OutputProcessor, and the task ends.

        Returns immediately if the task was already created. Must be called
        while an asyncio event loop is running (``asyncio.create_task``).
        """
        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        stat_loggers = self.stat_loggers if log_stats else None

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    num_outputs = len(outputs.outputs)

                    iteration_stats = IterationStats() if (
                        log_stats and num_outputs) else None

                    # Split outputs into chunks of at most
                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                    # event loop for too long.
                    if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
                        slices = (outputs.outputs, )
                    else:
                        slices = np.array_split(
                            outputs.outputs,
                            cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))

                    for i, outputs_slice in enumerate(slices):
                        # 2) Process EngineCoreOutputs.
                        processed_outputs = output_processor.process_outputs(
                            outputs_slice, outputs.timestamp, iteration_stats)
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed_outputs.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if i + 1 < len(slices):
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        await engine_core.abort_requests_async(
                            processed_outputs.reqs_to_abort)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if stat_loggers:
                        assert outputs.scheduler_stats is not None
                        AsyncLLM._record_stats(
                            stat_loggers[outputs.engine_index],
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                        )
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())

    async def abort(self, request_id: str) -> None:
        """Abort RequestId in OutputProcessor and EngineCore."""
        # The OutputProcessor returns the ids to abort in the EngineCore
        # (presumably including fanned-out child ids -- confirm).
        request_ids = self.output_processor.abort_requests((request_id, ))
        await self.engine_core.abort_requests_async(request_ids)

        if self.log_requests:
            logger.info("Aborted request %s.", request_id)

    @staticmethod
    def _record_stats(
        stat_loggers: list[StatLoggerBase],
        scheduler_stats: SchedulerStats,
        iteration_stats: Optional[IterationStats],
    ):
        """Record one step's stats with every logger in ``stat_loggers``.

        Static so that it can be used from the output_handler task
        without a circular ref to AsyncLLM.
        """
        for stat_logger in stat_loggers:
            stat_logger.record(scheduler_stats=scheduler_stats,
                               iteration_stats=iteration_stats)

    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ):
        """Encode/pooling requests are not supported on V1.

        Raises:
            ValueError: Always.
        """
        raise ValueError("Not Supported on V1 yet.")

    async def get_vllm_config(self) -> VllmConfig:
        """Return the ``VllmConfig`` this engine was constructed with."""
        return self.vllm_config

    async def get_model_config(self) -> ModelConfig:
        """Return the ``ModelConfig`` captured at construction time."""
        model_config = self.model_config
        return model_config

    async def get_decoding_config(self):
        """Not supported on V1.

        Raises:
            ValueError: Always.
        """
        raise ValueError("Not Supported on V1 yet.")

    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Return the InputPreprocessor owned by this engine's Processor."""
        return self.processor.input_preprocessor

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer to use for ``lora_request``.

        Delegates to the tokenizer group's ``get_lora_tokenizer``;
        presumably the base tokenizer is returned when ``lora_request`` is
        None (confirm against the tokenizer group implementation).
        """
        return self.tokenizer.get_lora_tokenizer(lora_request)

    async def is_tracing_enabled(self) -> bool:
        """This engine never reports tracing as enabled; always False."""
        tracing_enabled = False
        return tracing_enabled

    async def do_log_stats(
        self,
        scheduler_outputs=None,
        model_output=None,
    ) -> None:
        """Call ``log()`` on every stat logger of every DP rank.

        ``scheduler_outputs`` and ``model_output`` are accepted for
        interface compatibility and ignored.
        """
        for loggers in self.stat_loggers:
            for stat_logger in loggers:
                stat_logger.log()

    async def check_health(self) -> None:
        """Health probe hook; currently it only emits a debug log line."""
        message = "Called check_health."
        logger.debug(message)

    async def start_profile(self) -> None:
        """Forward ``profile_async(True)`` to the EngineCore client."""
        await self.engine_core.profile_async(True)

    async def stop_profile(self) -> None:
        """Forward ``profile_async(False)`` to the EngineCore client."""
        await self.engine_core.profile_async(False)

    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        """Ask the EngineCore client to reset its prefix cache.

        Raises:
            ValueError: If ``device`` is ``Device.CPU``.
        """
        if device == Device.CPU:
            raise ValueError("Not supported on CPU.")
        await self.engine_core.reset_prefix_cache_async()

    async def sleep(self, level: int = 1) -> None:
        """Forward ``sleep_async(level)`` to the EngineCore client."""
        await self.engine_core.sleep_async(level)

    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """Forward ``wake_up_async(tags)`` to the EngineCore client."""
        await self.engine_core.wake_up_async(tags)

    async def is_sleeping(self) -> bool:
        """Return the EngineCore client's ``is_sleeping_async()`` result."""
        return await self.engine_core.is_sleeping_async()

    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests.

        Returns:
            The boolean result reported by the EngineCore client.
        """
        return await self.engine_core.add_lora_async(lora_request)

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter.

        Returns:
            The boolean result reported by the EngineCore client.
        """
        return await self.engine_core.remove_lora_async(lora_id)

    async def list_loras(self) -> set[int]:
        """List all registered adapters (ids as reported by EngineCore)."""
        return await self.engine_core.list_loras_async()

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted.

        Returns:
            The boolean result reported by the EngineCore client.
        """
        return await self.engine_core.pin_lora_async(lora_id)

    async def collective_rpc(self,
                             method: str,
                             timeout: Optional[float] = None,
                             args: tuple = (),
                             kwargs: Optional[dict] = None):
        """
        Perform a collective RPC call to the given path.

        All arguments are forwarded unchanged to the EngineCore client's
        ``collective_rpc_async``, whose result is returned as-is.
        """
        return await self.engine_core.collective_rpc_async(
            method, timeout, args, kwargs)

    @property
    def is_running(self) -> bool:
        """True until the output-handler task has finished.

        Also True before the task has been started (it is None until then).
        """
        return self.output_handler is None or not self.output_handler.done()

    @property
    def is_stopped(self) -> bool:
        """Alias for ``errored``."""
        return self.errored

    @property
    def errored(self) -> bool:
        """True if the EngineCore client flags the engine as dead, or the
        output-handler task is no longer running."""
        return self.engine_core.resources.engine_dead or not self.is_running

    @property
    def dead_error(self) -> BaseException:
        """Return a new ``EngineDeadError`` instance."""
        return EngineDeadError()