async_llm.py 30.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import os
import socket
6
import time
7
from collections.abc import AsyncGenerator, Iterable, Mapping
8
from copy import copy
9
from typing import Any
10

11
import numpy as np
12
import torch
13

14
import vllm.envs as envs
15
from vllm.config import VllmConfig
16
17
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
18
from vllm.entrypoints.utils import _validate_truncation_size
19
from vllm.inputs import PromptType
20
21
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
22
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
23
from vllm.outputs import PoolingRequestOutput, RequestOutput
24
from vllm.plugins.io_processors import get_io_processor
25
from vllm.pooling_params import PoolingParams
26
from vllm.sampling_params import SamplingParams
27
from vllm.tasks import SupportedTask
28
from vllm.tracing import init_tracer
29
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
30
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
31
from vllm.usage.usage_lib import UsageContext
32
from vllm.utils import Device, cdiv
33
34
35
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
from vllm.utils.func_utils import deprecate_kwargs
36
from vllm.v1.engine import EngineCoreRequest
37
from vllm.v1.engine.core_client import EngineCoreClient
38
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
39
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
40
from vllm.v1.engine.parallel_sampling import ParentRequest
41
from vllm.v1.engine.processor import Processor
42
from vllm.v1.executor import Executor
43
44
45
46
47
from vllm.v1.metrics.loggers import (
    StatLoggerFactory,
    StatLoggerManager,
    load_stat_logger_plugin_factories,
)
48
from vllm.v1.metrics.prometheus import shutdown_prometheus
49
from vllm.v1.metrics.stats import IterationStats
50
51
52
53
54
55
56
57

logger = init_logger(__name__)


class AsyncLLM(EngineClient):
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: list[StatLoggerFactory] | None = None,
        aggregate_engine_logging: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Builds, in order: the input processor, the IO processor, the output
        processor, the engine-core client, the stat logger manager and
        (when VLLM_TORCH_PROFILER_DIR is set) a torch CPU profiler.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats. Stats logging is switched on
                regardless when any custom stat logger is present.
            usage_context: Usage context of the LLM. Not referenced in
                this constructor body.
            mm_registry: Multi-modal registry. Not referenced in this
                constructor body.
            use_cached_outputs: Whether to use cached outputs. Not
                referenced in this constructor body.
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop. Not
                referenced in this constructor body.
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            aggregate_engine_logging: forwarded to StatLoggerManager.
            client_addresses: forwarded to the engine-core client factory.
            client_count: forwarded to the engine-core client factory and
                to StatLoggerManager.
            client_index: forwarded to the engine-core client factory.

        Raises:
            ValueError: if envs.VLLM_USE_V1 is falsy.

        Returns:
            None
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github."
            )

        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.log_requests = log_requests

        # Explicitly passed factories first, then plugin-provided ones.
        custom_stat_loggers = [
            *(stat_loggers or []),
            *load_stat_logger_plugin_factories(),
        ]
        self.log_stats = log_stats or bool(custom_stat_loggers)
        if custom_stat_loggers and not log_stats:
            logger.info(
                "AsyncLLM created with log_stats=False, "
                "but custom stat loggers were found; "
                "enabling logging without default stat loggers."
            )

        tokenizer = (
            None
            if self.model_config.skip_tokenizer_init
            else init_tokenizer_from_configs(self.model_config)
        )

        self.processor = Processor(self.vllm_config, tokenizer)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        # self.tokenizer is the property backed by self.processor.tokenizer.
        self.output_processor = OutputProcessor(
            self.tokenizer, log_stats=self.log_stats
        )
        otlp_endpoint = self.observability_config.otlp_traces_endpoint
        if otlp_endpoint is not None:
            self.output_processor.tracer = init_tracer(
                "vllm.llm_engine", otlp_endpoint
            )

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Loggers. Default loggers are only enabled when the caller asked
        # for stats explicitly (log_stats), not when custom loggers forced
        # self.log_stats on.
        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=custom_stat_loggers,
                enable_default_loggers=log_stats,
                client_count=client_count,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        self.output_handler: asyncio.Task | None = None
        try:
            # Start output handler eagerly if we are in the asyncio eventloop.
            # get_running_loop() raises RuntimeError when there is none, in
            # which case the handler is started lazily by generate()/encode().
            asyncio.get_running_loop()
            self._run_output_handler()
        except RuntimeError:
            pass

        profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
        if not profiler_dir:
            self.profiler = None
        else:
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                profiler_dir,
            )
            worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
            trace_handler = torch.profiler.tensorboard_trace_handler(
                profiler_dir, worker_name=worker_name, use_gzip=True
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
                on_trace_ready=trace_handler,
            )

189
    @classmethod
    @deprecate_kwargs(
        "disable_log_requests",
        additional_message=(
            "This argument will have no effect. Use `enable_log_requests` instead."
        ),
    )
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_log_requests: bool = False,
        aggregate_engine_logging: bool = False,
        disable_log_stats: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
        disable_log_requests: bool = True,  # Deprecated, will be removed
    ) -> "AsyncLLM":
        """Create an AsyncLLM from an already-built VllmConfig.

        The executor class is derived from the config via
        Executor.get_class(). `disable_log_requests` is deprecated and is
        not read by this method.

        Raises:
            ValueError: if envs.VLLM_USE_V1 is falsy.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github."
            )

        # Create the AsyncLLM.
        executor_class = Executor.get_class(vllm_config)
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
            log_requests=enable_log_requests,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            aggregate_engine_logging=aggregate_engine_logging,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

233
234
235
236
237
238
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs.

        The engine config is built from `engine_args` for the given usage
        context; request/stat logging flags are taken from `engine_args`.
        """
        # Create the engine configs.
        config = engine_args.create_engine_config(usage_context)

        # Create the AsyncLLM.
        return cls(
            vllm_config=config,
            executor_class=Executor.get_class(config),
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            log_requests=engine_args.enable_log_requests,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
        )

258
259
260
    def __del__(self):
        """Release engine resources when the object is finalized."""
        self.shutdown()

261
262
263
    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC."""
        shutdown_prometheus()

        # Attributes are looked up defensively: shutdown() is also reached
        # from __del__, possibly on an instance whose __init__ did not finish.
        engine_core = getattr(self, "engine_core", None)
        if engine_core:
            engine_core.shutdown()

        handler_task = getattr(self, "output_handler", None)
        cancel_task_threadsafe(handler_task)
270

271
272
273
    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the supported tasks as reported by the engine core."""
        supported = await self.engine_core.get_supported_tasks_async()
        return supported

274
275
276
    async def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest | PromptType,
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        prompt_text: str | None = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        A raw prompt is first converted to an EngineCoreRequest by the
        processor; an EngineCoreRequest is used as-is. Sampling requests
        with n > 1 are fanned out into n child requests that all feed the
        same returned collector.

        Raises:
            EngineDeadError: if the engine has already errored.
        """
        if self.errored:
            raise EngineDeadError()

        is_pooling = isinstance(params, PoolingParams)

        # Create a new output collector for the request.
        queue = RequestOutputCollector(output_kind=params.output_kind)

        # Convert Input --> Request.
        if not isinstance(prompt, EngineCoreRequest):
            assert prompt_text is None
            logger.warning_once(
                "Processor has been moved under OpenAIServing and will "
                "be removed from AsyncLLM in v0.13."
            )
            request = self.processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
                data_parallel_rank,
            )
            if isinstance(prompt, str):
                prompt_text = prompt
            else:
                prompt_text = prompt.get("prompt")
        else:
            request = prompt

        # Single-output path: pooling requests and n == 1 sampling requests.
        if is_pooling or params.n == 1:
            await self._add_request(request, prompt_text, None, 0, queue)
            return queue

        # Get the updated SamplingParams from the request, which
        # were cloned/updated in processor.process_inputs above.
        parent_params = request.sampling_params
        assert parent_params is not None

        # Fan out child requests (for n>1).
        parent_request = ParentRequest(request_id, parent_params)
        num_children = parent_params.n
        for idx in range(num_children):
            child_id, child_params = parent_request.get_child_info(idx)
            # The last child reuses the original request object; earlier
            # children get a shallow copy.
            is_last_child = idx == num_children - 1
            child_request = request if is_last_child else copy(request)
            child_request.request_id = child_id
            child_request.sampling_params = child_params
            await self._add_request(
                child_request, prompt_text, parent_request, idx, queue
            )
        return queue
339

340
341
342
    async def _add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        index: int,
        queue: RequestOutputCollector,
    ):
        """Register one request with the output processor, then the engine core."""
        # Add the request to OutputProcessor (this process).
        self.output_processor.add_request(request, prompt, parent_req, index, queue)

        # Add the EngineCoreRequest to EngineCore (separate process).
        await self.engine_core.add_request_async(request)

        if not self.log_requests:
            return
        logger.info("Added request %s.", request.request_id)
356
357
358
359
360
361

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
362
    async def generate(
        self,
        prompt: EngineCoreRequest | PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Making an AsyncStream corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the Detokenizer.
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request AsyncStream.

        The caller of generate() iterates the returned AsyncGenerator,
        which yields each RequestOutput until one is marked finished.

        Raises:
            ValueError: for request validation errors (re-raised as-is).
            EngineDeadError: if the engine is dead (re-raised as-is).
            EngineGenerateError: wrapping any other unexpected exception.
        """
        cache_config = self.vllm_config.cache_config
        if cache_config.kv_sharing_fast_prefill and sampling_params.prompt_logprobs:
            raise ValueError(
                "--kv-sharing-fast-prefill produces incorrect logprobs for "
                "prompt tokens, please disable it when the requests need "
                "prompt logprobs"
            )

        try:
            # We start the output_handler on the first call to generate() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()

            # Truncation is validated here only when the caller did not
            # supply its own tokenization kwargs.
            if tokenization_kwargs is None:
                tokenization_kwargs = {}
                _validate_truncation_size(
                    self.model_config.max_model_len,
                    sampling_params.truncate_prompt_tokens,
                    tokenization_kwargs,
                )

            collector = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                prompt_text=prompt_text,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            while True:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = collector.get_nowait() or await collector.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished. The flag is read
                # before yielding control to the consumer.
                is_final = out.finished
                yield out
                if is_final:
                    break

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    def _run_output_handler(self):
        """Background loop: pulls from EngineCore and pushes to AsyncStreams.

        Idempotent: does nothing when the handler task already exists.
        Must be called while an asyncio event loop is running, because it
        schedules the loop with asyncio.create_task().
        """
        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        logger_manager = self.logger_manager
        processor = self.processor

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    core_outputs = outputs.outputs
                    num_outputs = len(core_outputs)

                    # Stats are only collected for non-empty batches.
                    if log_stats and num_outputs:
                        iteration_stats = IterationStats()
                    else:
                        iteration_stats = None

                    # Split outputs into chunks of at most
                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                    # event loop for too long.
                    chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
                    if num_outputs > chunk_size:
                        slices = np.array_split(
                            core_outputs,
                            cdiv(num_outputs, chunk_size),
                        )
                    else:
                        slices = (core_outputs,)

                    last_idx = len(slices) - 1
                    for i, outputs_slice in enumerate(slices):
                        # 2) Process EngineCoreOutputs.
                        processed_outputs = output_processor.process_outputs(
                            outputs_slice, outputs.timestamp, iteration_stats
                        )
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed_outputs.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if i < last_idx:
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        await engine_core.abort_requests_async(
                            processed_outputs.reqs_to_abort
                        )

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if logger_manager:
                        logger_manager.record(
                            engine_idx=outputs.engine_index,
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                            mm_cache_stats=processor.stat_mm_cache(),
                        )
            except Exception as e:
                # Hand the failure to the output processor so it can be
                # propagated instead of being lost with the task.
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())
537

538
    async def abort(self, request_id: str | Iterable[str]) -> None:
        """Abort RequestId in OutputProcessor and EngineCore.

        Accepts either a single request id or an iterable of ids. The ids
        returned by the output processor are the ones forwarded to the
        engine core.
        """
        if isinstance(request_id, str):
            request_ids = (request_id,)
        else:
            request_ids = as_list(request_id)

        ids_to_abort = self.output_processor.abort_requests(request_ids)
        await self.engine_core.abort_requests_async(ids_to_abort)

        if not self.log_requests:
            return
        logger.info("Aborted request(s) %s.", ",".join(request_ids))
549

550
    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        truncate_prompt_tokens: int | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Main function called by the API server to kick off a pooling request
            * 1) Making an AsyncStream corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request AsyncStream.

        The caller of encode() iterates the returned AsyncGenerator,
        which yields each PoolingRequestOutput until one is marked finished.

        Raises:
            ValueError: for request validation errors (re-raised as-is).
            EngineDeadError: if the engine is dead (re-raised as-is).
            EngineGenerateError: wrapping any other unexpected exception.
        """

        try:
            # We start the output_handler on the first call to encode() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()

            if tokenization_kwargs is None:
                tokenization_kwargs = {}
            _validate_truncation_size(
                self.model_config.max_model_len,
                truncate_prompt_tokens,
                tokenization_kwargs,
            )

            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, encode()
        # is cancelled or the generator is closed / garbage collected
        # (GeneratorExit is a BaseException and would otherwise bypass
        # every handler below). So, we abort the request if we end up
        # here, mirroring generate().
        except (asyncio.CancelledError, GeneratorExit):
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e
638

639
    @property
    def tokenizer(self) -> AnyTokenizer | None:
        """Tokenizer held by the input processor (may be None)."""
        processor = self.processor
        return processor.tokenizer

    @tokenizer.setter
    def tokenizer(self, tokenizer: AnyTokenizer | None) -> None:
        # The processor owns the tokenizer; this class keeps no copy.
        processor = self.processor
        processor.tokenizer = tokenizer
646

647
    async def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer.

        Raises:
            ValueError: if no tokenizer is available.
        """
        if self.tokenizer is not None:
            return self.tokenizer

        raise ValueError(
            "Unable to get tokenizer because skip_tokenizer_init is True"
        )
654
655

    async def is_tracing_enabled(self) -> bool:
        """Return True when an OTLP traces endpoint is configured."""
        endpoint = self.observability_config.otlp_traces_endpoint
        return endpoint is not None
657

658
    async def do_log_stats(self) -> None:
        """Ask the stat logger manager, if there is one, to log."""
        if not self.logger_manager:
            return
        self.logger_manager.log()
661
662
663

    async def check_health(self) -> None:
        """Raise `dead_error` if the engine has errored; otherwise return."""
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error
666
667

    async def start_profile(self) -> None:
        """Start profiling in the engine core and, if configured, locally.

        The local profiler's start() is run in a worker thread, concurrently
        with the engine-core call.
        """
        pending = [self.engine_core.profile_async(True)]
        local_profiler = self.profiler
        if local_profiler is not None:
            pending.append(asyncio.to_thread(local_profiler.start))
        await asyncio.gather(*pending)
672
673

    async def stop_profile(self) -> None:
        """Stop profiling in the engine core and, if configured, locally.

        The local profiler's stop() is run in a worker thread, concurrently
        with the engine-core call.
        """
        pending = [self.engine_core.profile_async(False)]
        local_profiler = self.profiler
        if local_profiler is not None:
            pending.append(asyncio.to_thread(local_profiler.stop))
        await asyncio.gather(*pending)
678

679
    async def reset_mm_cache(self) -> None:
        """Clear the processor's multi-modal cache, then the engine core's."""
        processor = self.processor
        processor.clear_mm_cache()
        await self.engine_core.reset_mm_cache_async()

683
    async def reset_prefix_cache(self, device: Device | None = None) -> None:
        """Reset the engine core's prefix cache.

        Raises:
            ValueError: if `device` is Device.CPU.
        """
        cpu_requested = device == Device.CPU
        if cpu_requested:
            raise ValueError("Not supported on CPU.")
        await self.engine_core.reset_prefix_cache_async()

688
    async def sleep(self, level: int = 1) -> None:
        """Reset the prefix cache, then put the engine core to sleep at `level`."""
        await self.reset_prefix_cache()
        engine_core = self.engine_core
        await engine_core.sleep_async(level)

692
    async def wake_up(self, tags: list[str] | None = None) -> None:
        """Wake the engine core up, forwarding the optional `tags`."""
        engine_core = self.engine_core
        await engine_core.wake_up_async(tags)
694

695
696
697
    async def is_sleeping(self) -> bool:
        """Return the engine core's sleeping state."""
        sleeping = await self.engine_core.is_sleeping_async()
        return sleeping

698
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        added = await self.engine_core.add_lora_async(lora_request)
        return added

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        removed = await self.engine_core.remove_lora_async(lora_id)
        return removed

706
    async def list_loras(self) -> set[int]:
        """List all registered adapters."""
        adapter_ids = await self.engine_core.list_loras_async()
        return adapter_ids

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        pinned = await self.engine_core.pin_lora_async(lora_id)
        return pinned
713

714
715
716
    async def collective_rpc(
        self,
        method: str,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
    ):
        """
        Perform a collective RPC call of the given method.

        All arguments are forwarded positionally to the engine core's
        collective_rpc_async(), whose result is returned unchanged.
        """
        engine_core = self.engine_core
        result = await engine_core.collective_rpc_async(method, timeout, args, kwargs)
        return result
727

728
729
730
731
732
733
734
735
    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """Wait for all requests to be drained.

        Polls the engine core once per second until it reports that no
        data-parallel engine is running.

        Args:
            drain_timeout: Maximum time to wait, in seconds.

        Raises:
            TimeoutError: if the engines are still running once
                `drain_timeout` seconds have elapsed.
        """
        # Use the monotonic clock: elapsed-time measurement must not be
        # affected by wall-clock adjustments (NTP steps, manual changes).
        deadline = time.monotonic() + drain_timeout
        while time.monotonic() < deadline:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

            logger.info("Engines are still running, waiting for requests to drain...")
            await asyncio.sleep(1)  # Wait 1 second before checking again

        raise TimeoutError(
            f"Timeout reached after {drain_timeout} seconds "
            "waiting for requests to drain."
        )
743

744
745
746
    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ):
        """
        Scale up or down the data parallel size by adding or removing
        engine cores.

        In-flight requests are drained first; the stored parallel config is
        updated only after the engine core has been scaled.

        Args:
            new_data_parallel_size: The new number of data parallel workers
            drain_timeout:
                Maximum time to wait for requests to drain (seconds)

        Raises:
            TimeoutError: propagated from wait_for_requests_to_drain() when
                requests do not drain within `drain_timeout`.
        """
        old_data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if old_data_parallel_size == new_data_parallel_size:
            logger.info(
                "Data parallel size is already %s, skipping scale",
                new_data_parallel_size,
            )
            return

        # This path handles both scale-up and scale-down, so the message
        # must not claim a direction.
        logger.info(
            "Waiting for requests to drain before scaling to %s engines...",
            new_data_parallel_size,
        )
        await self.wait_for_requests_to_drain(drain_timeout)
        logger.info(
            "Requests have been drained, proceeding with scale to %s engines",
            new_data_parallel_size,
        )
        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size

        # recreate stat loggers
        if new_data_parallel_size > old_data_parallel_size and self.log_stats:
            # TODO(rob): fix this after talking with Ray team.
            # This resets all the prometheus metrics since we
            # unregister during initialization. Need to understand
            # the intended behavior here better.
            # NOTE(review): an already-running output handler task holds
            # its own reference to the previous logger manager, and custom
            # stat loggers are not carried over here - confirm intended.
            self.logger_manager = StatLoggerManager(
                vllm_config=self.vllm_config,
                engine_idxs=list(range(new_data_parallel_size)),
                custom_stat_loggers=None,
            )

786
787
    @property
    def is_running(self) -> bool:
        """True while the output handler task has not finished.

        Also True before the handler has been started.
        """
        handler = self.output_handler
        # Is None before the loop is started.
        if handler is None:
            return True
        return not handler.done()
790
791
792

    @property
    def is_stopped(self) -> bool:
        """Alias of `errored`."""
        stopped = self.errored
        return stopped
794
795
796

    @property
    def errored(self) -> bool:
        """Truthy when the engine core is flagged dead or the handler stopped."""
        engine_dead = self.engine_core.resources.engine_dead
        return engine_dead or not self.is_running
798
799
800

    @property
    def dead_error(self) -> BaseException:
        """A fresh EngineDeadError instance, built on every access."""
        error = EngineDeadError()
        return error