core.py 28.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
import json
3
import os
4
import queue
5
import signal
6
import sys
7
8
import threading
import time
9
from collections import deque
10
from concurrent.futures import Future
11
from inspect import isclass, signature
12
from logging import DEBUG
13
from typing import Any, Callable, Optional, TypeVar, Union
14

15
import msgspec
16
17
import zmq

18
19
20
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
21
from vllm.logger import init_logger
22
from vllm.lora.request import LoRARequest
23
24
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
25
from vllm.utils import resolve_obj_by_qualname, zmq_socket_ctx
26
27
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                         unify_kv_cache_configs)
28
from vllm.v1.core.sched.interface import SchedulerInterface
29
30
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
31
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
32
                            EngineCoreRequestType, UtilityOutput)
33
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
34
from vllm.v1.executor.abstract import Executor
35
from vllm.v1.kv_cache_interface import KVCacheConfig
36
from vllm.v1.outputs import ModelRunnerOutput
37
from vllm.v1.request import Request, RequestStatus
38
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
39
from vllm.v1.structured_output import StructuredOutputManager
40
41
42
43
from vllm.version import __version__ as VLLM_VERSION

# Module-level logger used by every engine-core class in this file.
logger = init_logger(__name__)

# Polling timeout in seconds (per the ``_S`` suffix).
# NOTE(review): not referenced in the visible part of this file; kept
# because other modules may import it.
POLLING_TIMEOUT_S = 2.5

_R = TypeVar('_R')  # Return type for collective_rpc

48
49
50
51

class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the scheduler, the structured-output manager
    and the mirrored multimodal input cache, and exposes the per-step
    ``schedule -> execute -> update`` cycle plus a set of thin utility
    methods that delegate to the executor or scheduler.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 executor_class: type[Executor],
                 log_stats: bool,
                 executor_fail_callback: Optional[Callable] = None):
        """Build the executor, KV caches, scheduler and supporting caches.

        Args:
            vllm_config: Engine configuration. Its
                ``cache_config.num_gpu_blocks`` / ``num_cpu_blocks`` are
                overwritten with the values returned by
                ``_initialize_kv_caches``.
            executor_class: Executor type; instantiated with ``vllm_config``.
            log_stats: Stored on the instance and forwarded to the scheduler.
            executor_fail_callback: Optional callable registered with the
                executor through ``register_failure_callback``.
        """
        assert vllm_config.model_config.runner_type != "pooling"

        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(
                executor_fail_callback)

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
            self._initialize_kv_caches(vllm_config)

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler. The configured class may be given either as a
        # qualified name (string) or directly as a class object.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls)

        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size
            > 1,
            log_stats=self.log_stats,
        )

        # Setup MM Input Mapper.
        self.mm_input_cache_server = MirroredProcessingCache(
            vllm_config.model_config)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
                                                     SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)
        self.vllm_config = vllm_config

    def _initialize_kv_caches(
            self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
        """Profile memory, build per-worker KV cache configs and warm up.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, kv_cache_config)`` where
            ``num_cpu_blocks`` is always 0 and ``kv_cache_config`` is the
            first worker's config (used to initialize the scheduler).
        """
        # Use a monotonic high-resolution clock: this is a duration
        # measurement and must not be affected by wall-clock adjustments.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        # Profiles the peak memory usage of the model to determine how much
        # memory can be allocated for kv cache.
        available_gpu_memory = self.model_executor.determine_available_memory()

        # One spec and one memory figure are expected per worker.
        assert len(kv_cache_specs) == len(available_gpu_memory)
        # Get the kv cache tensor size
        kv_cache_configs = [
            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
                                available_gpu_memory_one_worker)
            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
            zip(kv_cache_specs, available_gpu_memory)
        ]

        # Since we use a shared centralized controller, we need the
        # `kv_cache_config` to be consistent across all workers to make sure
        # all the memory operators can be applied to all workers.
        unify_kv_cache_configs(kv_cache_configs)

        # All workers have the same kv_cache_config except layer names, so use
        # an arbitrary one to initialize the scheduler.
        assert all(cfg.num_blocks == kv_cache_configs[0].num_blocks
                   for cfg in kv_cache_configs)
        num_gpu_blocks = kv_cache_configs[0].num_blocks
        num_cpu_blocks = 0
        scheduler_kv_cache_config = kv_cache_configs[0]

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler.

        When the request carries multimodal hashes, ``request.mm_inputs`` is
        replaced in place by the result of the mirrored cache lookup before
        the request is converted and handed to the scheduler.
        """

        if request.mm_hashes is not None:
            # Here, if hash exists for a multimodal input, then it will be
            # fetched from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client cache, so
            # anything that has a hash must have a HIT cache entry here
            # as well.
            assert request.mm_inputs is not None
            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
                request.mm_inputs, request.mm_hashes)

        req = Request.from_engine_core_request(request)
        if req.use_structured_output:
            # Start grammar compilation asynchronously
            self.structured_output_manager.grammar_init(req)

        self.scheduler.add_request(req)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def step(self) -> EngineCoreOutputs:
        """Schedule, execute, and make output.

        Returns:
            The outputs produced by ``scheduler.update_from_output``, or an
            ``EngineCoreOutputs`` with no outputs (stats only) when the
            scheduler holds no requests.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return EngineCoreOutputs(
                outputs=[],
                scheduler_stats=self.scheduler.make_stats(),
            )
        scheduler_output = self.scheduler.schedule()
        output = self.model_executor.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, output)  # type: ignore

        return engine_core_outputs

    def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if the batch queue is not full.
        If a new batch is scheduled, return None right away without waiting
        for any model output. In other words, filling the batch queue has a
        higher priority than getting model outputs.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None
        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
        if not self.batch_queue.full():
            scheduler_output = self.scheduler.schedule()
            if scheduler_output.total_num_scheduled_tokens > 0:
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        scheduled_batch = (scheduler_output is not None
                           and scheduler_output.total_num_scheduled_tokens > 0)

        # If no more requests can be scheduled and the job queue is not empty,
        # block until the first batch in the job queue is finished.
        # TODO(comaniac): Ideally we should peek the first batch in the
        # job queue to check if it's finished before scheduling a new batch,
        # but peeking the first element in a queue is not thread-safe,
        # so we need more work.
        if not scheduled_batch and not self.batch_queue.empty():
            future, scheduler_output = self.batch_queue.get_nowait()
            # Blocking until the first result is available.
            model_output = future.result()
            self.batch_queue.task_done()
            engine_core_outputs = self.scheduler.update_from_output(
                scheduler_output, model_output)

        return engine_core_outputs

    def shutdown(self):
        """Clear the structured-output backend, then shut down the executor
        and the scheduler (each only if set)."""
        self.structured_output_manager.clear_backend()
        if self.model_executor:
            self.model_executor.shutdown()
        if self.scheduler:
            self.scheduler.shutdown()

    def profile(self, is_start: bool = True):
        """Forward a profiling request (``is_start``) to the executor."""
        self.model_executor.profile(is_start)

    def reset_prefix_cache(self):
        """Ask the scheduler to reset its prefix cache."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Forward a sleep request at the given ``level`` to the executor."""
        self.model_executor.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """Forward a wake-up request (optionally with ``tags``) to the
        executor."""
        self.model_executor.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` attribute."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Invoke ``execute_dummy_batch`` on the workers via collective RPC."""
        self.model_executor.collective_rpc("execute_dummy_batch")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate to ``model_executor.add_lora`` and return its result."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate to ``model_executor.remove_lora`` and return its result."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Delegate to ``model_executor.list_loras`` and return its result."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate to ``model_executor.pin_lora`` and return its result."""
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Delegate to ``model_executor.save_sharded_state`` with the same
        ``path``, ``pattern`` and ``max_size``."""
        self.model_executor.save_sharded_state(path=path,
                                               pattern=pattern,
                                               max_size=max_size)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Delegate to ``model_executor.collective_rpc``, passing ``method``,
        ``timeout``, ``args`` and ``kwargs`` positionally, and return its
        result."""
        return self.model_executor.collective_rpc(method, timeout, args,
                                                  kwargs)

313
314
315
316

class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process.

    Requests arrive on an input socket and are pushed onto ``input_queue``
    by one daemon thread; outputs placed on ``output_queue`` are serialized
    and sent by a second daemon thread. The busy loop itself only talks to
    the two queues.
    """

    # Sentinel placed on the output queue (and sent as-is on the output
    # socket) to tell the client that the engine core has died.
    ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD'

    def __init__(
        self,
        input_path: str,
        output_path: str,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        engine_index: int = 0,
    ):
        """Initialize the engine core and start the socket IO threads.

        Args:
            input_path: Address passed to ``zmq_socket_ctx`` for the input
                socket.
            output_path: Address passed to ``zmq_socket_ctx`` for the output
                socket.
            vllm_config: Engine configuration, forwarded to ``EngineCore``.
            executor_class: Executor type, forwarded to ``EngineCore``.
            log_stats: Forwarded to ``EngineCore``.
            engine_index: Index of this engine; used as the input socket
                identity and stamped on every outgoing ``EngineCoreOutputs``.
        """
        # Created before super().__init__ so the executor failure callback
        # can close over it.
        input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()

        # On executor failure, inject an EXECUTOR_FAILED request so the busy
        # loop raises from its own thread.
        executor_fail_callback = lambda: input_queue.put_nowait(
            (EngineCoreRequestType.EXECUTOR_FAILED, b''))

        super().__init__(vllm_config, executor_class, log_stats,
                         executor_fail_callback)

        self.step_fn = (self.step if self.batch_queue is None else
                        self.step_with_batch_queue)
        self.engines_running = False

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        self.input_queue = input_queue
        self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
        threading.Thread(target=self.process_input_socket,
                         args=(input_path, engine_index),
                         daemon=True).start()
        # The output thread is kept so _send_engine_dead can join it.
        self.output_thread = threading.Thread(
            target=self.process_output_socket,
            args=(output_path, engine_index),
            daemon=True)
        self.output_thread.start()

    @staticmethod
    def run_engine_core(*args,
                        dp_rank: int = 0,
                        local_dp_rank: int = 0,
                        **kwargs):
        """Launch EngineCore busy loop in background process.

        Installs SIGTERM/SIGINT handlers, constructs either a
        ``DPEngineCoreProc`` (when ``data_parallel_size > 1``) or an
        ``EngineCoreProc`` from ``*args``/``**kwargs`` and runs its busy
        loop. ``kwargs`` must contain ``"vllm_config"``.

        Raises:
            SystemExit: Re-raised after a termination signal.
            Exception: Any other failure is logged and re-raised; if the
                engine had been constructed, the client is notified first.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core: Optional[EngineCoreProc] = None
        try:
            parallel_config: ParallelConfig = kwargs[
                "vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1:
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                engine_core = EngineCoreProc(*args, **kwargs)

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            if engine_core is None:
                logger.exception("EngineCore failed to start.")
            else:
                logger.exception("EngineCore encountered a fatal error.")
                engine_core._send_engine_dead()
            # Bare raise keeps the original traceback intact.
            raise
        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()
            # 2) Step the engine core and return the outputs.
            self._process_engine_step()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed.

        Blocks on the input queue while the engine is idle (not
        ``engines_running`` and the scheduler has no requests), then drains
        whatever else is already queued without blocking.
        """

        waited = False
        while not self.engines_running and not (self.scheduler.has_requests()):
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                waited = True
            req = self.input_queue.get()
            self._handle_client_request(*req)

        if waited:
            logger.debug("EngineCore loop active.")

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

    def _process_engine_step(self):
        """Run one engine step and enqueue its outputs, if any.

        ``step_fn`` may return None (batch-queue mode), in which case
        nothing is put on the output queue.
        """

        # Step the engine core.
        outputs = self.step_fn()
        # Put EngineCoreOutputs into the output queue.
        if outputs is not None:
            self.output_queue.put_nowait(outputs)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Dispatch request from client.

        Raises:
            RuntimeError: For an ``EXECUTOR_FAILED`` request.
        """

        if request_type == EngineCoreRequestType.ADD:
            self.add_request(request)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                output.result = method(
                    *self._convert_msgspec_args(method, args))
            # Catch Exception, not BaseException: the signal handler in
            # run_engine_core raises SystemExit exactly once, so swallowing
            # it here would leave the process unable to terminate.
            except Exception as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (f"Call to {method_name} method"
                                          f" failed: {str(e)}")
            # A response is sent for the call whether it succeeded or failed.
            self.output_queue.put_nowait(
                EngineCoreOutputs(utility_output=output))
        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
            raise RuntimeError("Executor failed.")
        else:
            logger.error("Unrecognized input request type encountered: %s",
                         request_type)

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
         arg type, try converting to msgspec object.

        Args are paired positionally with the parameters of ``method``; only
        parameters annotated with a ``msgspec.Struct`` subclass are
        converted. Empty ``args`` is returned unchanged.
        """
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
            msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation) else v
            for v, p in zip(args, arg_types))

    def _send_engine_dead(self):
        """Send EngineDead status to the EngineCoreClient."""

        # Put ENGINE_CORE_DEAD in the queue.
        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)

        # Wait until msg sent by the daemon before shutdown.
        self.output_thread.join(timeout=5.0)
        if self.output_thread.is_alive():
            logger.fatal("vLLM shutdown signal from EngineCore failed "
                         "to send. Please report this issue.")

    def process_input_socket(self, input_path: str, engine_index: int):
        """Input socket IO thread.

        Connects a DEALER socket, sends a JSON READY message, then loops
        forever decoding incoming multipart messages and pushing
        ``(request_type, request)`` tuples onto ``input_queue``.
        """

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()
        # The engine index, as 2 little-endian bytes, is the socket identity.
        identity = engine_index.to_bytes(length=2, byteorder="little")

        with zmq_socket_ctx(input_path,
                            zmq.DEALER,
                            identity=identity,
                            bind=False) as socket:

            # Send ready message to front-end once input socket is connected.
            message_dict = {
                'type': 'READY',
                'num_gpu_blocks': self.vllm_config.cache_config.num_gpu_blocks,
            }
            message = json.dumps(message_dict).encode('utf-8')
            socket.send(message)

            while True:
                # (RequestType, RequestData)
                type_frame, *data_frames = socket.recv_multipart(copy=False)
                request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                # Deserialize the request data.
                decoder = add_request_decoder if (
                    request_type
                    == EngineCoreRequestType.ADD) else generic_decoder
                request = decoder.decode(data_frames)

                # Push to input queue for core busy loop.
                self.input_queue.put_nowait((request_type, request))

    def process_output_socket(self, output_path: str, engine_index: int):
        """Output socket IO thread.

        Takes items from ``output_queue``, stamps them with
        ``engine_index``, encodes them and sends them on a PUSH socket.
        Exits after forwarding the ``ENGINE_CORE_DEAD`` sentinel.
        """

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffers to reuse.
        reuse_buffers: list[bytearray] = []
        # Keep references to outputs and buffers until zmq is finished
        # with them (outputs may contain tensors/np arrays whose
        # backing buffers were extracted for zero-copy send).
        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

        # We must set linger to ensure the ENGINE_CORE_DEAD
        # message is sent prior to closing the socket.
        with zmq_socket_ctx(output_path, zmq.constants.PUSH,
                            linger=4000) as socket:
            while True:
                outputs = self.output_queue.get()
                if outputs == EngineCoreProc.ENGINE_CORE_DEAD:
                    socket.send(outputs, copy=False)
                    break
                assert not isinstance(outputs, bytes)
                outputs.engine_index = engine_index

                # Reclaim buffers that zmq is finished with. New entries
                # are added on the left, so the oldest is on the right.
                while pending and pending[-1][0].done:
                    reuse_buffers.append(pending.pop()[2])

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                buffers = encoder.encode_into(outputs, buffer)
                tracker = socket.send_multipart(buffers,
                                                copy=False,
                                                track=True)
                if not tracker.done:
                    # Only hold on to the outputs object when more than one
                    # buffer was sent.
                    ref = outputs if len(buffers) > 1 else None
                    pending.appendleft((tracker, ref, buffer))
                elif len(reuse_buffers) < 2:
                    # Keep at most 2 buffers to reuse.
                    reuse_buffers.append(buffer)
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    # Number of calls to _has_global_unfinished_reqs between two
    # finish-sync all-reduce operations across DP peers.
    _FINISH_SYNC_INTERVAL_STEPS = 24

    def __init__(
        self,
        input_path: str,
        output_path: str,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        """Set up the per-rank environment, then initialize the engine.

        The DP size and ranks are read from ``vllm_config.parallel_config``;
        the DP rank is passed to ``EngineCoreProc`` as the engine index.
        On CUDA-like platforms ``CUDA_VISIBLE_DEVICES`` is overwritten with
        the devices assigned to this local DP rank.
        """
        # Add process-specific prefix to stdout and stderr before
        # we initialize the engine.
        from multiprocessing import current_process
        process_name = current_process().name
        pid = os.getpid()
        _add_prefix(sys.stdout, process_name, pid)
        _add_prefix(sys.stderr, process_name, pid)

        dp_size = vllm_config.parallel_config.data_parallel_size
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        from vllm.platforms import current_platform
        if current_platform.is_cuda_alike():
            from vllm.platforms.cuda import device_id_to_physical_device_id
            tp_size = vllm_config.parallel_config.tensor_parallel_size
            # This local rank gets the contiguous slice of tp_size logical
            # device ids [local_dp_rank * tp_size, (local_dp_rank+1) * tp_size).
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
                str(device_id_to_physical_device_id(i))
                for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
                               tp_size))

        self.local_dp_rank = local_dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
        self.current_wave = 0

        # Initialize the engine after setting up environment.
        super().__init__(input_path, output_path, vllm_config, executor_class,
                         log_stats, dp_rank)

        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.counter = 0

    def shutdown(self):
        """Shut down the engine, then destroy the DP process group.

        The process group is destroyed even if the base-class shutdown
        raises, so it is not leaked on a failed shutdown.
        """
        try:
            super().shutdown()
        finally:
            # getattr guards against dp_group never having been assigned.
            if dp_group := getattr(self, "dp_group", None):
                stateless_destroy_torch_distributed_process_group(dp_group)

    def add_request(self, request: EngineCoreRequest):
        """Add a request, reconciling its wave number with ours first.

        A request from a later wave advances ``current_wave``. A request
        from an earlier wave received while the engines are not running
        triggers a ``start_wave`` notification on the output queue.
        """
        if request.current_wave != self.current_wave:
            if request.current_wave > self.current_wave:
                self.current_wave = request.current_wave
            elif not self.engines_running:
                # Request received for an already-completed wave, notify
                # front-end that we need to start the next one.
                self.output_queue.put_nowait(
                    EngineCoreOutputs(start_wave=self.current_wave))

        super().add_request(request)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Handle ``START_DP_WAVE`` here; defer everything else to the base
        class.

        For ``START_DP_WAVE`` the request payload is the wave number. Waves
        older than ``current_wave`` are ignored.
        """
        if request_type == EngineCoreRequestType.START_DP_WAVE:
            new_wave: int = request
            if new_wave >= self.current_wave:
                self.current_wave = new_wave
                if not self.engines_running:
                    logger.debug("EngineCore starting idle loop for wave %d.",
                                 new_wave)
                    self.engines_running = True
        else:
            super()._handle_client_request(request_type, request)

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()

            if local_unfinished_reqs:
                # 2) Step the engine core.
                self._process_engine_step()

                # Check if we have now finished all requests.
                local_unfinished_reqs = (
                    self.scheduler.has_unfinished_requests())
            else:
                if self.scheduler.has_finished_requests():
                    # There are no unfinished requests, but there are some
                    # finished requests remaining to be removed from the
                    # batch state. This engine step won't perform a forward
                    # pass but will flush the finished requests to ensure
                    # up-to-date state is returned in the engine outputs.
                    self._process_engine_step()

                if not self.engines_running:
                    # All engines are idle.
                    continue

                # There must be unfinished requests in DP peers, run a
                # dummy forward pass.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.engines_running = self._has_global_unfinished_reqs(
                local_unfinished_reqs)

            if not self.engines_running:
                if self.local_dp_rank == 0:
                    # Notify client that we are pausing the loop.
                    logger.debug("Wave %d finished, pausing engine loop.",
                                 self.current_wave)
                    self.output_queue.put_nowait(
                        EngineCoreOutputs(wave_complete=self.current_wave))
                self.current_wave += 1

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Report whether any DP engine still has unfinished requests.

        Args:
            local_unfinished: Whether this engine has unfinished requests.

        Returns:
            True without synchronizing on all but every
            ``_FINISH_SYNC_INTERVAL_STEPS``-th call; on that call, the
            result of ``ParallelConfig.has_unfinished_dp``.
        """

        # Optimization - only perform finish-sync all-reduce every
        # _FINISH_SYNC_INTERVAL_STEPS steps.
        self.counter += 1
        # "<" rather than "!=" so a counter that ever overshoots the
        # interval still triggers a sync instead of skipping it forever.
        if self.counter < self._FINISH_SYNC_INTERVAL_STEPS:
            return True
        self.counter = 0

        return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                local_unfinished)