"fern/pages/kubernetes/installation-guide.md" did not exist on "95dd9426d8117e0d0ea07492744ab8261a399725"
async_llm.py 31.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
5
import os
import socket
6
import time
7
from collections.abc import AsyncGenerator, Iterable, Mapping
8
from copy import copy
9
from typing import Any, cast
10

11
import numpy as np
12
import torch
13

14
import vllm.envs as envs
15
from vllm.config import VllmConfig
16
from vllm.engine.arg_utils import AsyncEngineArgs
17
from vllm.engine.protocol import Device, EngineClient
18
from vllm.entrypoints.utils import _validate_truncation_size
19
from vllm.inputs import PromptType
20
21
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
22
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
23
from vllm.outputs import PoolingRequestOutput, RequestOutput
24
from vllm.plugins.io_processors import get_io_processor
25
from vllm.pooling_params import PoolingParams
26
from vllm.sampling_params import SamplingParams
27
from vllm.tasks import SupportedTask
28
from vllm.tracing import init_tracer
29
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
30
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
31
from vllm.usage.usage_lib import UsageContext
32
33
34
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
from vllm.utils.func_utils import deprecate_kwargs
35
from vllm.utils.math_utils import cdiv
36
from vllm.v1.engine import EngineCoreRequest
37
from vllm.v1.engine.core_client import EngineCoreClient
38
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
39
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
40
from vllm.v1.engine.parallel_sampling import ParentRequest
41
from vllm.v1.engine.processor import Processor
42
from vllm.v1.executor import Executor
43
44
45
46
47
from vllm.v1.metrics.loggers import (
    StatLoggerFactory,
    StatLoggerManager,
    load_stat_logger_plugin_factories,
)
48
from vllm.v1.metrics.prometheus import shutdown_prometheus
49
from vllm.v1.metrics.stats import IterationStats
50
51
52
53
54
55
56
57

# Module-level logger, obtained from vLLM's init_logger under this module's name.
logger = init_logger(__name__)


class AsyncLLM(EngineClient):
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: list[StatLoggerFactory] | None = None,
        aggregate_engine_logging: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """
        Assemble an AsyncLLM: input processor, output processor, engine-core
        client, optional stat loggers and an optional CPU profiler.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats. Forced on when custom stat
                loggers are present (default loggers stay off in that case).
            usage_context: Usage context of the LLM (not referenced in this
                constructor body).
            mm_registry: Multi-modal registry (not referenced in this
                constructor body).
            use_cached_outputs: Whether to use cached outputs (not referenced
                in this constructor body).
            log_requests: Whether to log requests.
            start_engine_loop: Whether to start the engine loop (not
                referenced in this constructor body).
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            aggregate_engine_logging: forwarded to StatLoggerManager.
            client_addresses: forwarded to the engine-core client factory.
            client_count: forwarded to the engine-core client factory and
                to StatLoggerManager.
            client_index: forwarded to the engine-core client factory.

        Raises:
            ValueError: if ``envs.VLLM_USE_V1`` is falsy.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github."
            )

        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.log_requests = log_requests

        # Caller-supplied factories first, then any provided by plugins.
        logger_factories = [*(stat_loggers or [])]
        logger_factories.extend(load_stat_logger_plugin_factories())

        custom_loggers_present = bool(logger_factories)
        self.log_stats = log_stats or custom_loggers_present
        if custom_loggers_present and not log_stats:
            logger.info(
                "AsyncLLM created with log_stats=False, "
                "but custom stat loggers were found; "
                "enabling logging without default stat loggers."
            )

        tokenizer = (
            None
            if self.model_config.skip_tokenizer_init
            else init_tokenizer_from_configs(self.model_config)
        )

        self.processor = Processor(self.vllm_config, tokenizer)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        # `self.tokenizer` is the property backed by `self.processor`.
        self.output_processor = OutputProcessor(
            self.tokenizer, log_stats=self.log_stats
        )
        otlp_endpoint = self.observability_config.otlp_traces_endpoint
        if otlp_endpoint is not None:
            self.output_processor.tracer = init_tracer(
                "vllm.llm_engine", otlp_endpoint
            )

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Loggers.
        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                engine_idxs=self.engine_core.engine_ranks_managed,
                custom_stat_loggers=logger_factories,
                # Default loggers follow the caller's original request only.
                enable_default_loggers=log_stats,
                client_count=client_count,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        self.output_handler: asyncio.Task | None = None
        try:
            # Start output handler eagerly if we are in the asyncio eventloop.
            # get_running_loop() raises RuntimeError when no loop is running.
            asyncio.get_running_loop()
            self._run_output_handler()
        except RuntimeError:
            pass

        profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
        if not profiler_dir:
            self.profiler = None
        else:
            logger.info(
                "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s",  # noqa: E501
                profiler_dir,
            )
            trace_worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
            trace_handler = torch.profiler.tensorboard_trace_handler(
                profiler_dir, worker_name=trace_worker_name, use_gzip=True
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
                on_trace_ready=trace_handler,
            )

188
    @classmethod
    @deprecate_kwargs(
        "disable_log_requests",
        additional_message=(
            "This argument will have no effect. Use `enable_log_requests` instead."
        ),
    )
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_log_requests: bool = False,
        aggregate_engine_logging: bool = False,
        disable_log_stats: bool = False,
        client_addresses: dict[str, str] | None = None,
        client_count: int = 1,
        client_index: int = 0,
        disable_log_requests: bool = True,  # Deprecated, will be removed
    ) -> "AsyncLLM":
        """
        Alternate constructor: build an AsyncLLM from a ready VllmConfig.

        The executor class is resolved via ``Executor.get_class(vllm_config)``.
        ``enable_log_requests`` maps to ``log_requests`` and
        ``disable_log_stats`` is inverted into ``log_stats``.
        ``disable_log_requests`` is accepted but never read here.

        Raises:
            ValueError: if ``envs.VLLM_USE_V1`` is falsy.
        """
        v1_enabled = envs.VLLM_USE_V1
        if not v1_enabled:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github."
            )

        # Create the LLMEngine.
        executor_cls = Executor.get_class(vllm_config)
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_cls,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
            log_requests=enable_log_requests,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            aggregate_engine_logging=aggregate_engine_logging,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

232
233
234
235
236
237
    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
    ) -> "AsyncLLM":
        """
        Create an AsyncLLM from the EngineArgs.

        The engine config is derived from ``engine_args`` for the given
        ``usage_context``; request/stat logging flags are read from
        ``engine_args.enable_log_requests`` and ``engine_args.disable_log_stats``.
        """
        # Create the engine configs.
        engine_config = engine_args.create_engine_config(usage_context)

        # Create the AsyncLLM.
        return cls(
            vllm_config=engine_config,
            executor_class=Executor.get_class(engine_config),
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            log_requests=engine_args.enable_log_requests,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
        )

257
258
259
    def __del__(self):
        """Finalizer: release engine resources by delegating to shutdown()."""
        # shutdown() looks its attributes up defensively with getattr.
        self.shutdown()

260
261
262
    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC."""
        shutdown_prometheus()

        # Attributes are fetched with getattr because they may not exist yet
        # (this method is also invoked from __del__).
        core_client = getattr(self, "engine_core", None)
        if core_client:
            core_client.shutdown()

        if (handler_task := getattr(self, "output_handler", None)) is not None:
            cancel_task_threadsafe(handler_task)
271

272
273
274
    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the supported tasks as reported by the engine core."""
        supported = await self.engine_core.get_supported_tasks_async()
        return supported

275
276
277
    async def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest | PromptType,
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
        prompt_text: str | None = None,
    ) -> RequestOutputCollector:
        """
        Add new request to the AsyncLLM.

        A prompt that is already an ``EngineCoreRequest`` is used as-is;
        anything else goes through ``self.processor.process_inputs``. Pooling
        requests and sampling requests with ``n == 1`` are registered once;
        otherwise one child request per sample is registered, all feeding the
        same collector.

        Returns:
            The RequestOutputCollector that will receive this request's outputs.

        Raises:
            EngineDeadError: if the engine is in an errored state.
        """
        if self.errored:
            raise EngineDeadError()

        single_request = isinstance(params, PoolingParams)

        # Create a new output collector for the request.
        collector = RequestOutputCollector(output_kind=params.output_kind)

        # Convert Input --> Request.
        if not isinstance(prompt, EngineCoreRequest):
            assert prompt_text is None
            logger.warning_once(
                "Processor has been moved under OpenAIServing and will "
                "be removed from AsyncLLM in v0.13."
            )
            request = self.processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
                data_parallel_rank,
            )
            # Recover the raw prompt text when the prompt carries one.
            if isinstance(prompt, str):
                prompt_text = prompt
            elif isinstance(prompt, Mapping):
                prompt_text = cast(str | None, prompt.get("prompt"))
        else:
            request = prompt

        if single_request or params.n == 1:
            await self._add_request(request, prompt_text, None, 0, collector)
            return collector

        # Get the updated SamplingParams from the request, which
        # were cloned/updated in processor.process_inputs above.
        parent_params = request.sampling_params
        assert parent_params is not None

        # Fan out child requests (for n>1).
        parent = ParentRequest(request_id, parent_params)
        for child_idx in range(parent_params.n):
            child_id, child_params = parent.get_child_info(child_idx)
            # The last child reuses the original request object; the others
            # get a shallow copy.
            if child_idx == parent_params.n - 1:
                child = request
            else:
                child = copy(request)
            child.request_id = child_id
            child.sampling_params = child_params
            await self._add_request(child, prompt_text, parent, child_idx, collector)
        return collector
343

344
345
346
    async def _add_request(
        self,
        request: EngineCoreRequest,
        prompt: str | None,
        parent_req: ParentRequest | None,
        index: int,
        queue: RequestOutputCollector,
    ):
        """
        Register one request with the OutputProcessor, then send it to the
        engine core. Logs the request id when request logging is enabled.
        """
        # Add the request to OutputProcessor (this process).
        out_proc = self.output_processor
        out_proc.add_request(request, prompt, parent_req, index, queue)

        # Add the EngineCoreRequest to EngineCore (separate process).
        await self.engine_core.add_request_async(request)

        if not self.log_requests:
            return
        logger.info("Added request %s.", request.request_id)
360
361
362
363
364
365

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
366
    async def generate(
        self,
        prompt: EngineCoreRequest | PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        *,
        prompt_text: str | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        data_parallel_rank: int | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request.

        Steps:
            1) Make sure the background output handler task is running.
            2) Register the request via add_request(), which returns the
               per-request output collector.
            3) Yield RequestOutputs pulled from that collector until one is
               marked finished.

        The background output handler pulls outputs from EngineCore and
        pushes them into the per-request collector; this generator only
        drains it.

        Raises:
            ValueError: for invalid requests (e.g. prompt logprobs combined
                with kv-sharing fast prefill).
            EngineDeadError: if the engine has died.
            EngineGenerateError: wrapping any other unexpected exception.
        """
        if self.vllm_config.cache_config.kv_sharing_fast_prefill:
            if sampling_params.prompt_logprobs:
                raise ValueError(
                    "--kv-sharing-fast-prefill produces incorrect logprobs for "
                    "prompt tokens, please disable it when the requests need "
                    "prompt logprobs"
                )

        try:
            # We start the output_handler on the first call to generate() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()

            # Truncation is validated only when the caller supplied no
            # tokenization kwargs of its own.
            if tokenization_kwargs is None:
                tokenization_kwargs = {}
                _validate_truncation_size(
                    self.model_config.max_model_len,
                    sampling_params.truncate_prompt_tokens,
                    tokenization_kwargs,
                )

            collector = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
                prompt_text=prompt_text,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            while True:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                output = collector.get_nowait() or await collector.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                is_final = output.finished
                assert isinstance(output, RequestOutput)
                yield output
                if is_final:
                    break

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as exc:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from exc

    def _run_output_handler(self):
        """
        Start the background output-handler task unless one already exists.

        The task loops forever: it awaits outputs from the engine core,
        feeds them to the output processor in chunks, aborts requests the
        processor asks to abort, and records stats when a logger manager is
        present. Any exception is logged and propagated to the output
        processor.
        """
        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly.
        # Hence everything the task needs is captured as a plain local.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        logger_manager = self.logger_manager
        processor = self.processor

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    core_outputs = outputs.outputs
                    num_outputs = len(core_outputs)

                    if log_stats and num_outputs:
                        iteration_stats = IterationStats()
                    else:
                        iteration_stats = None

                    # Split outputs into chunks of at most
                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                    # event loop for too long.
                    chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
                    if num_outputs > chunk_size:
                        chunks = np.array_split(
                            core_outputs, cdiv(num_outputs, chunk_size)
                        )
                    else:
                        chunks = (core_outputs,)
                    num_chunks = len(chunks)

                    for chunk_idx, chunk in enumerate(chunks):
                        # 2) Process EngineCoreOutputs.
                        processed = output_processor.process_outputs(
                            chunk, outputs.timestamp, iteration_stats
                        )
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if chunk_idx + 1 < num_chunks:
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        await engine_core.abort_requests_async(
                            processed.reqs_to_abort
                        )

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if logger_manager:
                        logger_manager.record(
                            engine_idx=outputs.engine_index,
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                            mm_cache_stats=processor.stat_mm_cache(),
                        )
            except Exception as exc:
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(exc)

        self.output_handler = asyncio.create_task(output_handler())
542

543
    async def abort(self, request_id: str | Iterable[str]) -> None:
        """
        Abort RequestId in OutputProcessor and EngineCore.

        Accepts a single id or an iterable of ids. The ids returned by the
        output processor's abort_requests() are the ones forwarded to the
        engine core.
        """
        if isinstance(request_id, str):
            request_ids = (request_id,)
        else:
            request_ids = as_list(request_id)

        ids_to_abort = self.output_processor.abort_requests(request_ids)
        await self.engine_core.abort_requests_async(ids_to_abort)

        if not self.log_requests:
            return
        logger.info("Aborted request(s) %s.", ",".join(request_ids))
554

555
    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        truncate_prompt_tokens: int | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Main function called by the API server to kick off a pooling request.

        Steps:
            1) Make sure the background output handler task is running.
            2) Validate the truncation size against the model's max length.
            3) Register the request via add_request(), which returns the
               per-request output collector.
            4) Yield PoolingRequestOutputs pulled from that collector until
               one is marked finished.

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and pushing them into the
        per-request collector.

        The caller of encode() iterates the returned AsyncGenerator to
        receive each PoolingRequestOutput.

        Raises:
            ValueError: for invalid requests.
            EngineDeadError: if the engine has died.
            EngineGenerateError: wrapping any other unexpected exception.
        """

        try:
            # We start the output_handler on the first call to encode() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()

            if tokenization_kwargs is None:
                tokenization_kwargs = {}
            _validate_truncation_size(
                self.model_config.max_model_len,
                truncate_prompt_tokens,
                tokenization_kwargs,
            )

            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
                trace_headers=trace_headers,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, encode() is
        # cancelled or the generator is closed / garbage collected
        # (GeneratorExit). GeneratorExit is a BaseException, so the
        # `except Exception` clause below would not see it; catch it here,
        # as generate() does, so the request is aborted instead of leaking.
        except (asyncio.CancelledError, GeneratorExit):
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e
643

644
    @property
    def tokenizer(self) -> AnyTokenizer | None:
        """Tokenizer held by the input processor (may be None)."""
        input_processor = self.processor
        return input_processor.tokenizer

    @tokenizer.setter
    def tokenizer(self, tokenizer: AnyTokenizer | None) -> None:
        """Replace the tokenizer stored on the input processor."""
        input_processor = self.processor
        input_processor.tokenizer = tokenizer
651

652
    async def get_tokenizer(self) -> AnyTokenizer:
        """
        Return the tokenizer.

        Raises:
            ValueError: if no tokenizer is available.
        """
        if self.tokenizer is not None:
            return self.tokenizer

        raise ValueError(
            "Unable to get tokenizer because skip_tokenizer_init is True"
        )
659
660

    async def is_tracing_enabled(self) -> bool:
        """Report whether an OTLP traces endpoint is configured."""
        otlp_endpoint = self.observability_config.otlp_traces_endpoint  # type: ignore
        return otlp_endpoint is not None
662

663
    async def do_log_stats(self) -> None:
        """Emit stats through the logger manager; no-op when it is unset."""
        if not self.logger_manager:
            return
        self.logger_manager.log()
666
667
668

    async def check_health(self) -> None:
        """
        Health probe.

        Raises:
            BaseException: ``self.dead_error`` when the engine is errored.
        """
        logger.debug("Called check_health.")
        if not self.errored:
            return
        raise self.dead_error
671
672

    async def start_profile(self) -> None:
        """
        Start profiling in the engine core and, when a local profiler was
        configured, start it in a worker thread; both run concurrently.
        """
        pending = [self.engine_core.profile_async(True)]
        if self.profiler is not None:
            # profiler.start is a blocking call, so keep it off the event loop.
            pending.append(asyncio.to_thread(self.profiler.start))
        await asyncio.gather(*pending)
677
678

    async def stop_profile(self) -> None:
        """
        Stop profiling in the engine core and, when a local profiler was
        configured, stop it in a worker thread; both run concurrently.
        """
        pending = [self.engine_core.profile_async(False)]
        if self.profiler is not None:
            # profiler.stop is a blocking call, so keep it off the event loop.
            pending.append(asyncio.to_thread(self.profiler.stop))
        await asyncio.gather(*pending)
683

684
    async def reset_mm_cache(self) -> None:
        """Clear the processor's multi-modal cache, then the engine core's."""
        input_processor = self.processor
        input_processor.clear_mm_cache()
        await self.engine_core.reset_mm_cache_async()

688
    async def reset_prefix_cache(self, device: Device | None = None) -> None:
        """
        Reset the engine core's prefix cache.

        Raises:
            ValueError: if ``device`` is ``Device.CPU``.
        """
        cpu_requested = device == Device.CPU
        if cpu_requested:
            raise ValueError("Not supported on CPU.")
        await self.engine_core.reset_prefix_cache_async()

693
    async def sleep(self, level: int = 1) -> None:
        """
        Put the engine to sleep at ``level``.

        The prefix cache is reset first; afterwards the sleep state is
        recorded on the logger manager when one exists.
        """
        await self.reset_prefix_cache()
        await self.engine_core.sleep_async(level)

        stats_manager = self.logger_manager
        if stats_manager is None:
            return
        stats_manager.record_sleep_state(1, level)

700
    async def wake_up(self, tags: list[str] | None = None) -> None:
        """
        Wake the engine up, forwarding ``tags`` to the engine core, then
        record the awake state on the logger manager when one exists.
        """
        await self.engine_core.wake_up_async(tags)

        stats_manager = self.logger_manager
        if stats_manager is None:
            return
        stats_manager.record_sleep_state(0, 0)

706
707
708
    async def is_sleeping(self) -> bool:
        """Return the sleeping state as reported by the engine core."""
        sleeping = await self.engine_core.is_sleeping_async()
        return sleeping

709
    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        added = await self.engine_core.add_lora_async(lora_request)
        return added

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        removed = await self.engine_core.remove_lora_async(lora_id)
        return removed

717
    async def list_loras(self) -> set[int]:
        """List all registered adapters."""
        adapter_ids = await self.engine_core.list_loras_async()
        return adapter_ids

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        pinned = await self.engine_core.pin_lora_async(lora_id)
        return pinned
724

725
726
727
    async def collective_rpc(
        self,
        method: str,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
    ):
        """
        Perform a collective RPC call to the given path.

        All four arguments are forwarded positionally to the engine core's
        collective_rpc_async(); its result is returned unchanged.
        """
        rpc_result = await self.engine_core.collective_rpc_async(
            method, timeout, args, kwargs
        )
        return rpc_result
738

739
740
741
742
743
744
745
746
    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """
        Wait for all requests to be drained.

        Polls the engine core once per second until it reports that no
        data-parallel engines are running.

        Args:
            drain_timeout: maximum time to wait, in seconds.

        Raises:
            TimeoutError: if the engines are still running once
                ``drain_timeout`` seconds have elapsed.
        """
        # Use the monotonic clock: a wall-clock adjustment (NTP step, manual
        # change) must neither cut the drain window short nor extend it.
        start_time = time.monotonic()
        while time.monotonic() - start_time < drain_timeout:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

            logger.info("Engines are still running, waiting for requests to drain...")
            await asyncio.sleep(1)  # Wait 1 second before checking again

        raise TimeoutError(
            f"Timeout reached after {drain_timeout} seconds "
            "waiting for requests to drain."
        )
754

755
756
757
    async def scale_elastic_ep(
        self, new_data_parallel_size: int, drain_timeout: int = 300
    ):
        """
        Scale up or down the data parallel size by adding or removing
        engine cores.

        In-flight requests are drained first; the stored parallel config is
        updated after the engine core has been rescaled.

        Args:
            new_data_parallel_size: The new number of data parallel workers
            drain_timeout:
                Maximum time to wait for requests to drain (seconds)
        """
        old_data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if new_data_parallel_size == old_data_parallel_size:
            logger.info(
                "Data parallel size is already %s, skipping scale",
                new_data_parallel_size,
            )
            return

        # NOTE(review): this message says "scaling up" but is emitted for
        # scale-down as well.
        logger.info(
            "Waiting for requests to drain before scaling up to %s engines...",
            new_data_parallel_size,
        )
        await self.wait_for_requests_to_drain(drain_timeout)

        logger.info(
            "Requests have been drained, proceeding with scale to %s engines",
            new_data_parallel_size,
        )
        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size

        # recreate stat loggers (only when scaling up with stats enabled)
        if new_data_parallel_size <= old_data_parallel_size or not self.log_stats:
            return

        # TODO(rob): fix this after talking with Ray team.
        # This resets all the prometheus metrics since we
        # unregister during initialization. Need to understand
        # the intended behavior here better.
        # NOTE(review): custom stat loggers are not carried over
        # (custom_stat_loggers=None) - confirm this is intended.
        engine_indices = list(range(new_data_parallel_size))
        self.logger_manager = StatLoggerManager(
            vllm_config=self.vllm_config,
            engine_idxs=engine_indices,
            custom_stat_loggers=None,
        )

797
798
    @property
    def is_running(self) -> bool:
        """True while the output handler task is unstarted or not yet done."""
        handler_task = self.output_handler
        # Is None before the loop is started.
        if handler_task is None:
            return True
        return not handler_task.done()
801
802
803

    @property
    def is_stopped(self) -> bool:
        """Alias of ``errored``."""
        stopped = self.errored
        return stopped
805
806
807

    @property
    def errored(self) -> bool:
        """True when the engine core is flagged dead or the handler has stopped."""
        engine_dead = self.engine_core.resources.engine_dead
        # Short-circuits: is_running is only consulted when not flagged dead.
        return engine_dead or not self.is_running
809
810
811

    @property
    def dead_error(self) -> BaseException:
        """A fresh EngineDeadError instance, built on each access."""
        error = EngineDeadError()
        return error