async_llm.py 27 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import time
5
from collections.abc import AsyncGenerator, Mapping
6
from copy import copy
7
from typing import Any, Optional, Union
8

9
10
import numpy as np

11
import vllm.envs as envs
12
13
14
from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
15
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
16
from vllm.inputs import PromptType
17
from vllm.inputs.preprocess import InputPreprocessor
18
19
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
20
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
21
from vllm.outputs import PoolingRequestOutput, RequestOutput
22
23
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
24
from vllm.sampling_params import SamplingParams
25
26
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
27
28
29
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
30
from vllm.utils import Device, cdiv
31
from vllm.v1.engine import EngineCoreRequest
32
from vllm.v1.engine.core_client import EngineCoreClient
33
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
34
35
from vllm.v1.engine.output_processor import (OutputProcessor,
                                             RequestOutputCollector)
36
from vllm.v1.engine.parallel_sampling import ParentRequest
37
from vllm.v1.engine.processor import Processor
38
from vllm.v1.executor.abstract import Executor
39
40
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
                                     setup_default_loggers)
41
from vllm.v1.metrics.prometheus import shutdown_prometheus
42
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
43
44
45
46
47
48
49
50
51

# Module-level logger used by AsyncLLM and its background output handler.
logger = init_logger(__name__)


class AsyncLLM(EngineClient):
    """V1 asyncio engine client.

    Converts prompts into ``EngineCoreRequest`` objects in this process,
    sends them to an EngineCore started in a background process, and streams
    results back to callers through per-request ``RequestOutputCollector``
    queues that a background asyncio task fills (see
    ``_run_output_handler``).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        client_addresses: Optional[dict[str, str]] = None,
        client_index: int = 0,
    ) -> None:
        """
        Create an AsyncLLM.

        Args:
            vllm_config: global configuration.
            executor_class: an Executor impl, e.g. MultiprocExecutor.
            log_stats: Whether to log stats.
            usage_context: Usage context of the LLM. Accepted for API
                compatibility; not read by this constructor.
            mm_registry: Multi-modal registry.
            use_cached_outputs: Accepted for API compatibility; not read by
                this constructor.
            log_requests: Whether to log requests.
            start_engine_loop: Accepted for API compatibility; not read by
                this constructor. The output handler is started eagerly when
                a running event loop exists, otherwise on the first
                generate()/encode() call.
            stat_loggers: customized stat loggers for the engine.
                If not provided, default stat loggers will be used.
                PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
                IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
            client_addresses: Optional addresses forwarded to
                ``EngineCoreClient.make_async_mp_client``.
            client_index: Client index forwarded to
                ``EngineCoreClient.make_async_mp_client``.

        Raises:
            ValueError: if ``envs.VLLM_USE_V1`` is false.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        # Ensure we can serialize custom transformer configs
        maybe_register_config_serialize_by_value()

        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.log_requests = log_requests
        self.log_stats = log_stats

        # Set up stat loggers; independent set for each DP rank.
        self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
            vllm_config=vllm_config,
            log_stats=self.log_stats,
            engine_num=vllm_config.parallel_config.data_parallel_size,
            custom_stat_loggers=stat_loggers,
        )

        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            lora_config=vllm_config.lora_config)

        # Processor (converts Inputs --> EngineCoreRequests).
        self.processor = Processor(
            vllm_config=vllm_config,
            tokenizer=self.tokenizer,
            mm_registry=mm_registry,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(self.tokenizer,
                                                log_stats=self.log_stats)

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_async_mp_client(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_index=client_index,
        )

        if self.stat_loggers:
            for stat_logger in self.stat_loggers[0]:
                stat_logger.log_engine_initialized()

        self.output_handler: Optional[asyncio.Task] = None
        try:
            # Start output handler eagerly if we are in the asyncio eventloop.
            asyncio.get_running_loop()
            self._run_output_handler()
        except RuntimeError:
            # No running loop: the handler is started lazily by
            # generate()/encode().
            pass

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        disable_log_requests: bool = False,
        disable_log_stats: bool = False,
        client_addresses: Optional[dict[str, str]] = None,
        client_index: int = 0,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from a fully-built VllmConfig.

        The executor class is chosen via ``Executor.get_class``; the
        ``disable_*`` flags are inverted into ``log_requests``/``log_stats``.

        Raises:
            ValueError: if ``envs.VLLM_USE_V1`` is false.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        # Create the LLMEngine.
        return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
            log_requests=not disable_log_requests,
            log_stats=not disable_log_stats,
            usage_context=usage_context,
            client_addresses=client_addresses,
            client_index=client_index,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs."""

        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    def __del__(self):
        self.shutdown()

    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC.

        Uses ``getattr`` so it is safe to call on a partially-constructed
        instance (e.g. from ``__del__`` after ``__init__`` raised).
        """

        shutdown_prometheus()

        if engine_core := getattr(self, "engine_core", None):
            engine_core.shutdown()

        if handler := getattr(self, "output_handler", None):
            handler.cancel()

    async def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM.

        For sampling requests with ``n > 1`` the request is fanned out into
        ``n`` child requests that share one output collector.

        Returns:
            The collector that the output handler pushes outputs into.

        Raises:
            EngineDeadError: if the engine has already errored.
        """

        if self.errored:
            raise EngineDeadError()

        is_pooling = isinstance(params, PoolingParams)

        # Create a new output collector for the request.
        queue = RequestOutputCollector(output_kind=params.output_kind)

        # Convert Input --> Request.
        prompt_str, request = self.processor.process_inputs(
            request_id, prompt, params, arrival_time, lora_request,
            tokenization_kwargs, trace_headers, prompt_adapter_request,
            priority, data_parallel_rank)

        if is_pooling or params.n == 1:
            await self._add_request(request, prompt_str, None, 0, queue)
            return queue

        # Fan out child requests (for n>1).
        # Capture the fan-out count up front and keep the child id/params in
        # their own names: the original rebound `params` to the child's
        # params, so the last-child check compared against the child's `n`
        # instead of the number of children.
        num_children = params.n
        parent_request = ParentRequest(request_id, params)
        for idx in range(num_children):
            child_id, child_params = parent_request.get_child_info(idx)
            # Earlier children get shallow copies of the still-unmodified
            # original; the last child reuses the original object itself.
            child_request = (request
                             if idx == num_children - 1 else copy(request))
            child_request.request_id = child_id
            child_request.sampling_params = child_params
            await self._add_request(child_request, prompt_str, parent_request,
                                    idx, queue)
        return queue

    async def _add_request(self, request: EngineCoreRequest,
                           prompt: Optional[str],
                           parent_req: Optional[ParentRequest], index: int,
                           queue: RequestOutputCollector):
        """Register a single (possibly child) request locally and remotely."""

        # Add the request to OutputProcessor (this process).
        self.output_processor.add_request(request, prompt, parent_req, index,
                                          queue)

        # Add the EngineCoreRequest to EngineCore (separate process).
        await self.engine_core.add_request_async(request)

        if self.log_requests:
            logger.info("Added request %s.", request.request_id)

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Making an AsyncStream corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the Detokenizer.
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request AsyncStream.

        The caller of generate() iterates the returned AsyncGenerator,
        returning the RequestOutput back to the caller.

        Raises:
            EngineDeadError: if the engine died.
            ValueError: if the request failed validation.
            EngineGenerateError: wrapping any other unexpected error.
        """

        try:
            # We start the output_handler on the first call to generate() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()

            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
                data_parallel_rank=data_parallel_rank,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()

                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, generate()
        # is cancelled or the generator is garbage collected. So,
        # we abort the request if we end up here.
        except (asyncio.CancelledError, GeneratorExit):
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the generate() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    def _run_output_handler(self):
        """Background loop: pulls from EngineCore and pushes to AsyncStreams.

        Idempotent: does nothing if the handler task already exists. Must be
        called with a running event loop (uses ``asyncio.create_task``).
        """

        if self.output_handler is not None:
            return

        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
        # object, or else it won't be garbage collected and cleaned up properly.
        engine_core = self.engine_core
        output_processor = self.output_processor
        log_stats = self.log_stats
        stat_loggers = self.stat_loggers if log_stats else None

        async def output_handler():
            try:
                while True:
                    # 1) Pull EngineCoreOutputs from the EngineCore.
                    outputs = await engine_core.get_output_async()
                    num_outputs = len(outputs.outputs)

                    iteration_stats = IterationStats() if (
                        log_stats and num_outputs) else None

                    # Split outputs into chunks of at most
                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                    # event loop for too long.
                    if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
                        slices = (outputs.outputs, )
                    else:
                        slices = np.array_split(
                            outputs.outputs,
                            cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))

                    for i, outputs_slice in enumerate(slices):
                        # 2) Process EngineCoreOutputs.
                        processed_outputs = output_processor.process_outputs(
                            outputs_slice, outputs.timestamp, iteration_stats)
                        # NOTE: RequestOutputs are pushed to their queues.
                        assert not processed_outputs.request_outputs

                        # Allow other asyncio tasks to run between chunks
                        if i + 1 < len(slices):
                            await asyncio.sleep(0)

                        # 3) Abort any reqs that finished due to stop strings.
                        await engine_core.abort_requests_async(
                            processed_outputs.reqs_to_abort)

                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.
                    if stat_loggers:
                        AsyncLLM._record_stats(
                            stat_loggers[outputs.engine_index],
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                        )
            except Exception as e:
                # Surface the failure to every waiting request instead of
                # leaving callers blocked on their queues.
                logger.exception("AsyncLLM output_handler failed.")
                output_processor.propagate_error(e)

        self.output_handler = asyncio.create_task(output_handler())

    async def abort(self, request_id: str) -> None:
        """Abort RequestId in OutputProcessor and EngineCore."""

        request_ids = self.output_processor.abort_requests((request_id, ))
        await self.engine_core.abort_requests_async(request_ids)

        if self.log_requests:
            logger.info("Aborted request %s.", request_id)

    @staticmethod
    def _record_stats(
        stat_loggers: list[StatLoggerBase],
        scheduler_stats: Optional[SchedulerStats],
        iteration_stats: Optional[IterationStats],
    ):
        """static so that it can be used from the output_handler task
        without a circular ref to AsyncLLM."""
        for stat_logger in stat_loggers:
            stat_logger.record(scheduler_stats=scheduler_stats,
                               iteration_stats=iteration_stats)

    async def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Main function called by the API server to kick off a pooling request
            * 1) Making an AsyncStream corresponding to the Request.
            * 2) Processing the Input.
            * 3) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background AsyncIO task,
        pulling outputs from EngineCore and putting them into the
        per-request AsyncStream.

        The caller of encode() iterates the returned AsyncGenerator,
        returning the PoolingRequestOutput back to the caller.

        Raises:
            EngineDeadError: if the engine died.
            ValueError: if the request failed validation.
            EngineGenerateError: wrapping any other unexpected error.
        """

        try:
            # We start the output_handler on the first call to encode() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            self._run_output_handler()

            q = await self.add_request(
                request_id,
                prompt,
                pooling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            finished = False
            while not finished:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() or await q.get()
                assert isinstance(out, PoolingRequestOutput)
                # Note: both OutputProcessor and EngineCore handle their
                # own request cleanup based on finished.
                finished = out.finished
                yield out

        # If the request is disconnected by the client, encode() is
        # cancelled or the generator is closed / garbage collected.
        # GeneratorExit derives from BaseException, so without catching it
        # here a closed generator would skip the abort entirely and leak
        # the request in OutputProcessor and EngineCore.
        except (asyncio.CancelledError, GeneratorExit):
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s aborted.", request_id)
            raise

        # Engine is dead. Do not abort since we shut down.
        except EngineDeadError:
            if self.log_requests:
                logger.info("Request %s failed (engine dead).", request_id)
            raise

        # Request validation error.
        except ValueError:
            if self.log_requests:
                logger.info("Request %s failed (bad request).", request_id)
            raise

        # Unexpected error in the encode() task (possibly recoverable).
        except Exception as e:
            await self.abort(request_id)
            if self.log_requests:
                logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    async def get_vllm_config(self) -> VllmConfig:
        return self.vllm_config

    async def get_model_config(self) -> ModelConfig:
        return self.model_config

    async def get_decoding_config(self):
        raise ValueError("Not Supported on V1 yet.")

    async def get_input_preprocessor(self) -> InputPreprocessor:
        return self.processor.input_preprocessor

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        return self.tokenizer.get_lora_tokenizer(lora_request)

    async def is_tracing_enabled(self) -> bool:
        return False

    async def do_log_stats(
        self,
        scheduler_outputs=None,
        model_output=None,
    ) -> None:
        """Flush every stat logger of every DP rank.

        ``scheduler_outputs`` and ``model_output`` are accepted for interface
        compatibility and are not used.
        """
        for loggers in self.stat_loggers:
            for stat_logger in loggers:
                stat_logger.log()

    async def check_health(self) -> None:
        """Raise ``self.dead_error`` if the engine has errored."""
        logger.debug("Called check_health.")
        if self.errored:
            raise self.dead_error

    async def start_profile(self) -> None:
        await self.engine_core.profile_async(True)

    async def stop_profile(self) -> None:
        await self.engine_core.profile_async(False)

    async def reset_mm_cache(self) -> None:
        """Reset multi-modal caches both locally and in the engine core."""
        self.processor.mm_registry.reset_processor_cache()
        self.processor.mm_input_cache_client.reset()
        await self.engine_core.reset_mm_cache_async()

    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        """Reset the engine's prefix cache.

        Raises:
            ValueError: if ``device`` is ``Device.CPU``.
        """
        if device == Device.CPU:
            raise ValueError("Not supported on CPU.")
        await self.engine_core.reset_prefix_cache_async()

    async def sleep(self, level: int = 1) -> None:
        await self.engine_core.sleep_async(level)

    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
        await self.engine_core.wake_up_async(tags)

    async def is_sleeping(self) -> bool:
        return await self.engine_core.is_sleeping_async()

    async def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return await self.engine_core.add_lora_async(lora_request)

    async def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return await self.engine_core.remove_lora_async(lora_id)

    async def list_loras(self) -> set[int]:
        """List all registered adapters."""
        return await self.engine_core.list_loras_async()

    async def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return await self.engine_core.pin_lora_async(lora_id)

    async def collective_rpc(self,
                             method: str,
                             timeout: Optional[float] = None,
                             args: tuple = (),
                             kwargs: Optional[dict] = None):
        """
        Perform a collective RPC call to the given path.
        """
        return await self.engine_core.collective_rpc_async(
            method, timeout, args, kwargs)

    async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
        """Wait for all requests to be drained.

        Polls the engine core once per second until no data-parallel engine
        reports running work.

        Args:
            drain_timeout: Maximum time to wait, in seconds.

        Raises:
            TimeoutError: if engines are still running after
                ``drain_timeout`` seconds.
        """
        # Monotonic clock: elapsed time is unaffected by system clock
        # adjustments, which could otherwise shorten or extend the timeout.
        start_time = time.monotonic()
        while time.monotonic() - start_time < drain_timeout:
            if not self.engine_core.dp_engines_running():
                logger.info("Engines are idle, requests have been drained")
                return

            logger.info(
                "Engines are still running, waiting for requests to drain...")
            await asyncio.sleep(1)  # Wait 1 second before checking again

        raise TimeoutError(f"Timeout reached after {drain_timeout} seconds "
                           "waiting for requests to drain.")

    async def scale_elastic_ep(self,
                               new_data_parallel_size: int,
                               drain_timeout: int = 300):
        """
        Scale up or down the data parallel size by adding or removing
        engine cores.

        Args:
            new_data_parallel_size: The new number of data parallel workers
            drain_timeout:
                Maximum time to wait for requests to drain (seconds)

        Raises:
            TimeoutError: if requests do not drain within ``drain_timeout``.
        """
        old_data_parallel_size = \
            self.vllm_config.parallel_config.data_parallel_size
        if old_data_parallel_size == new_data_parallel_size:
            logger.info("Data parallel size is already %s, skipping scale",
                        new_data_parallel_size)
            return
        logger.info(
            "Waiting for requests to drain before "
            "scaling to %s engines...", new_data_parallel_size)
        await self.wait_for_requests_to_drain(drain_timeout)
        logger.info(
            "Requests have been drained, proceeding with scale "
            "to %s engines", new_data_parallel_size)
        await self.engine_core.scale_elastic_ep(new_data_parallel_size)
        self.vllm_config.parallel_config.data_parallel_size = \
            new_data_parallel_size

        # Resize the per-engine stat logger list in place (the output handler
        # task holds a reference to this same list object).
        if new_data_parallel_size > old_data_parallel_size:
            new_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
                vllm_config=self.vllm_config,
                log_stats=self.log_stats,
                engine_num=new_data_parallel_size,
                custom_stat_loggers=None,
            )
            # Append only the entries for engines we don't have yet. Slicing
            # from the current length avoids the `[-0:]` pitfall, which would
            # append every logger when there is nothing new.
            self.stat_loggers.extend(new_loggers[len(self.stat_loggers):])
        else:
            # Drop loggers of removed engines. Unlike repeated pop(), this is
            # a no-op when the list is empty (e.g. stats disabled).
            del self.stat_loggers[new_data_parallel_size:]

    @property
    def is_running(self) -> bool:
        """True until the output handler task has finished."""
        # Is None before the loop is started.
        return self.output_handler is None or not self.output_handler.done()

    @property
    def is_stopped(self) -> bool:
        return self.errored

    @property
    def errored(self) -> bool:
        """True if the engine core is dead or the output handler stopped."""
        return self.engine_core.resources.engine_dead or not self.is_running

    @property
    def dead_error(self) -> BaseException:
        return EngineDeadError()