core.py 24.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
import os
3
import queue
4
import signal
5
import sys
6
7
import threading
import time
8
from concurrent.futures import Future
9
from inspect import isclass, signature
10
from logging import DEBUG
11
from typing import Any, Callable, Optional, TypeVar, Union
12

13
import msgspec
14
import psutil
15
16
17
import zmq
import zmq.asyncio

18
19
20
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
21
from vllm.logger import init_logger
22
from vllm.lora.request import LoRARequest
23
24
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
25
26
from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
                        zmq_socket_ctx)
27
28
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                         unify_kv_cache_configs)
29
from vllm.v1.core.sched.interface import SchedulerInterface
30
31
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
32
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
33
                            EngineCoreRequestType, UtilityOutput)
34
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
35
from vllm.v1.executor.abstract import Executor
36
from vllm.v1.kv_cache_interface import KVCacheConfig
37
from vllm.v1.outputs import ModelRunnerOutput
38
from vllm.v1.request import Request, RequestStatus
39
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
40
from vllm.v1.structured_output import StructuredOutputManager
41
42
43
44
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# Polling timeout in seconds.
# NOTE(review): not referenced in the visible part of this module; presumably
# consumed by client-side code that imports it — confirm before changing.
POLLING_TIMEOUT_S = 2.5

_R = TypeVar('_R')  # Return type for collective_rpc

class EngineCore:
    """Inner loop of vLLM's Engine."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        """Build the executor, KV caches, scheduler and MM input cache.

        Args:
            vllm_config: Engine configuration. Its ``cache_config`` is
                updated in place with the block counts computed while
                initializing the KV caches.
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Whether the scheduler should collect stats.
        """
        assert vllm_config.model_config.runner_type != "pooling"

        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
            self._initialize_kv_caches(vllm_config)

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler. `scheduler_cls` is either a class or a
        # fully-qualified name that is resolved at runtime.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls)

        self.scheduler: SchedulerInterface = Scheduler(
            scheduler_config=vllm_config.scheduler_config,
            model_config=vllm_config.model_config,
            cache_config=vllm_config.cache_config,
            lora_config=vllm_config.lora_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=(
                vllm_config.parallel_config.data_parallel_size > 1),
            log_stats=self.log_stats,
        )

        # Setup MM Input Mapper.
        self.mm_input_cache_server = MMInputCacheServer(
            vllm_config.model_config)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
                                                     SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

    def _initialize_kv_caches(
            self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
        """Profile memory, size the KV caches and initialize the workers.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks, kv_cache_config)``.
            ``kv_cache_config`` is the first of the unified per-worker
            configs and is used to build the scheduler. ``num_cpu_blocks``
            is always 0 here.
        """
        # perf_counter is monotonic, so the reported duration cannot be
        # skewed by wall-clock adjustments during the (long) init.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        # Profiles the peak memory usage of the model to determine how much
        # memory can be allocated for kv cache.
        available_gpu_memory = self.model_executor.determine_available_memory()

        assert len(kv_cache_specs) == len(available_gpu_memory)
        # Get the kv cache tensor size
        kv_cache_configs = [
            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
                                available_gpu_memory_one_worker)
            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
            zip(kv_cache_specs, available_gpu_memory)
        ]

        # Since we use a shared centralized controller, we need the
        # `kv_cache_config` to be consistent across all workers to make sure
        # all the memory operators can be applied to all workers.
        unify_kv_cache_configs(kv_cache_configs)

        # All workers have the same kv_cache_config except layer names, so use
        # an arbitrary one to initialize the scheduler.
        assert all(cfg.num_blocks == kv_cache_configs[0].num_blocks
                   for cfg in kv_cache_configs)
        num_gpu_blocks = kv_cache_configs[0].num_blocks
        num_cpu_blocks = 0
        scheduler_kv_cache_config = kv_cache_configs[0]

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler."""

        if request.mm_hashes is not None:
            # Here, if hash exists for a multimodal input, then it will be
            # fetched from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client cache, so
            # anything that has a hash must have a HIT cache entry here
            # as well.
            assert request.mm_inputs is not None
            request.mm_inputs = self.mm_input_cache_server.get_and_update(
                request.mm_inputs, request.mm_hashes)

        req = Request.from_engine_core_request(request)
        if req.use_structured_output:
            # Start grammar compilation asynchronously
            self.structured_output_manager.grammar_init(req)

        self.scheduler.add_request(req)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def step(self) -> EngineCoreOutputs:
        """Schedule, execute, and make output."""

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return EngineCoreOutputs(
                outputs=[],
                scheduler_stats=self.scheduler.make_stats(),
            )
        scheduler_output = self.scheduler.schedule()
        output = self.model_executor.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, output)  # type: ignore

        return engine_core_outputs

    def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if there are unscheduled requests
        and the job queue is not full. If a new batch is scheduled, directly
        return an empty engine core output. In other words, we won't check
        and return model outputs before the batch queue is full.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None
        # If there are unscheduled requests and the job queue
        # is not full, schedule a new batch. Note that this is not blocking.
        if (self.scheduler.get_num_unscheduled_requests() > 0
                and not self.batch_queue.full()):
            scheduler_output = self.scheduler.schedule()
            if scheduler_output.total_num_scheduled_tokens > 0:
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        scheduled_batch = (scheduler_output is not None
                           and scheduler_output.total_num_scheduled_tokens > 0)

        # If no more requests can be scheduled and the job queue is not empty,
        # block until the first batch in the job queue is finished.
        if not scheduled_batch and not self.batch_queue.empty():
            future, scheduler_output = self.batch_queue.get_nowait()
            # Blocking until the first result is available.
            model_output = future.result()
            self.batch_queue.task_done()
            engine_core_outputs = self.scheduler.update_from_output(
                scheduler_output, model_output)

        return engine_core_outputs

    def shutdown(self):
        """Shut down the model executor."""
        self.model_executor.shutdown()

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop profiling in the executor."""
        self.model_executor.profile(is_start)

    def reset_prefix_cache(self):
        """Reset the scheduler's prefix cache."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Put the executor to sleep at the given level."""
        self.model_executor.sleep(level)

    def wake_up(self):
        """Wake the executor up from sleep."""
        self.model_executor.wake_up()

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` flag."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run ``execute_dummy_batch`` on all workers via collective RPC."""
        self.model_executor.collective_rpc("execute_dummy_batch")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA adapter into the executor."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter from the executor."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the ids of the LoRA adapters known to the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA adapter in the executor."""
        return self.model_executor.pin_lora(lora_id)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Forward a collective RPC to the executor and return its results."""
        return self.model_executor.collective_rpc(method, timeout, args,
                                                  kwargs)

class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

    def __init__(
        self,
        input_path: str,
        output_path: str,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        engine_index: int = 0,
    ):
        """Initialize the engine core and start the socket IO threads.

        Args:
            input_path: ZMQ address the input (PULL) socket connects to.
            output_path: ZMQ address the output (PUSH) socket connects to.
            vllm_config: Engine configuration.
            executor_class: Executor type used by ``EngineCore``.
            log_stats: Whether the scheduler should collect stats.
            engine_index: Stamped on every ``EngineCoreOutputs`` sent.
        """
        super().__init__(vllm_config, executor_class, log_stats)

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        self.input_queue: queue.Queue[tuple[EngineCoreRequestType,
                                            Any]] = queue.Queue()
        self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
        threading.Thread(target=self.process_input_socket,
                         args=(input_path, ),
                         daemon=True).start()
        threading.Thread(target=self.process_output_socket,
                         args=(output_path, engine_index),
                         daemon=True).start()

        self.global_unfinished_reqs = False

        self.step_fn = (self.step if self.batch_queue is None else
                        self.step_with_batch_queue)

    @staticmethod
    def run_engine_core(*args,
                        dp_rank: int = 0,
                        local_dp_rank: int = 0,
                        ready_pipe,
                        **kwargs):
        """Launch EngineCore busy loop in background process.

        Builds a ``DPEngineCoreProc`` when data parallelism is enabled,
        otherwise an ``EngineCoreProc``. Sends ``{"status": "READY"}`` on
        ``ready_pipe`` once constructed, then runs the busy loop. On an
        unexpected exception the parent process is sent SIGUSR1.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        parent_process = psutil.Process().parent()
        engine_core: Optional[EngineCoreProc] = None
        try:
            parallel_config: ParallelConfig = kwargs[
                "vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1:
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                engine_core = EngineCoreProc(*args, **kwargs)

            # Send Readiness signal to EngineClient.
            ready_pipe.send({"status": "READY"})

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore interrupted.")

        except Exception:
            traceback = get_exception_traceback()
            logger.error("EngineCore hit an exception: %s", traceback)
            parent_process.send_signal(signal.SIGUSR1)

        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()
            # 2) Step the engine core and return the outputs.
            self._process_engine_step()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed."""

        waited = False
        while not self.global_unfinished_reqs and not (
                self.scheduler.has_requests()):
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                waited = True
            req = self.input_queue.get()
            self._handle_client_request(*req)

        if waited:
            logger.debug(
                "EngineCore loop active - local unfinished: %s, finished: %s.",
                self.scheduler.has_unfinished_requests(),
                self.scheduler.has_finished_requests())

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

    def _process_engine_step(self):
        """Called only when there are unfinished local requests."""

        # Step the engine core.
        outputs = self.step_fn()
        # Put EngineCoreOutputs into the output queue.
        if outputs is not None:
            self.output_queue.put_nowait(outputs)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Dispatch request from client."""

        if request_type == EngineCoreRequestType.ADD:
            self.add_request(request)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.START_DP:
            if not self.global_unfinished_reqs:
                logger.debug("EngineCore starting idle loop.")
                self.global_unfinished_reqs = True
        elif request_type == EngineCoreRequestType.UTILITY:
            call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                output.result = method(
                    *self._convert_msgspec_args(method, args))
            # Catch Exception, not BaseException: the SIGTERM/SIGINT handler
            # in run_engine_core raises SystemExit only once, so swallowing
            # it here would leave the process unable to shut down.
            except Exception as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (f"Call to {method_name} method"
                                          f" failed: {str(e)}")
            self.output_queue.put_nowait(
                EngineCoreOutputs(utility_output=output))

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
         arg type, try converting to msgspec object."""
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
            msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation) else v
            for v, p in zip(args, arg_types))

    def process_input_socket(self, input_path: str):
        """Input socket IO thread.

        Receives ``(type, data)`` multipart messages, decodes them and
        pushes ``(request_type, request)`` onto ``input_queue``.
        """

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
            while True:
                # (RequestType, RequestData)
                type_frame, data_frame = socket.recv_multipart(copy=False)
                request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                # Deserialize the request data.
                decoder = add_request_decoder if (
                    request_type
                    == EngineCoreRequestType.ADD) else generic_decoder
                request = decoder.decode(data_frame.buffer)

                # Push to input queue for core busy loop.
                self.input_queue.put_nowait((request_type, request))

    def process_output_socket(self, output_path: str, engine_index: int):
        """Output socket IO thread.

        Takes ``EngineCoreOutputs`` from ``output_queue``, stamps them with
        ``engine_index``, serializes them and pushes them to ``output_path``.
        """

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Send buffer, reused across messages when that is safe.
        buffer = bytearray()
        # zmq.MessageTracker for the last zero-copy send of `buffer`
        # (None before the first send).
        tracker = None

        with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
            while True:
                outputs = self.output_queue.get()
                outputs.engine_index = engine_index
                if tracker is not None and not tracker.done:
                    # With copy=False zmq may still be reading the previous
                    # buffer from its IO thread; encoding into it now would
                    # corrupt (or resize) that in-flight message, so use a
                    # fresh buffer instead.
                    buffer = bytearray()
                encoder.encode_into(outputs, buffer)
                tracker = socket.send(buffer, copy=False, track=True)


# Shared output object that DPEngineCoreProc enqueues to tell the client the
# engine is pausing its busy loop (no DP peer has unfinished requests).
# NOTE(review): this one instance is reused for every pause, and
# process_output_socket assigns `engine_index` on each object it sends, so
# that field is overwritten on this shared object — fine while each process
# hosts a single engine, but confirm before sharing it more widely.
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    # Number of busy-loop iterations that perform a (real or dummy) forward
    # pass between all-reduces that check whether any DP peer still has
    # unfinished requests.
    _FINISH_SYNC_INTERVAL = 16

    def __init__(
        self,
        input_path: str,
        output_path: str,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        """Set up the per-rank environment and DP group, then the engine.

        Reads the DP size/rank from ``vllm_config.parallel_config``; on
        CUDA-alike platforms restricts ``CUDA_VISIBLE_DEVICES`` to this
        rank's ``tensor_parallel_size`` devices. The DP rank is used as the
        engine index for outputs.
        """
        # Add process-specific prefix to stdout and stderr before
        # we initialize the engine.
        from multiprocessing import current_process
        process_name = current_process().name
        pid = os.getpid()
        _add_prefix(sys.stdout, process_name, pid)
        _add_prefix(sys.stderr, process_name, pid)

        dp_size = vllm_config.parallel_config.data_parallel_size
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        from vllm.platforms import current_platform
        if current_platform.is_cuda_alike():
            from vllm.platforms.cuda import device_id_to_physical_device_id
            tp_size = vllm_config.parallel_config.tensor_parallel_size
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
                str(device_id_to_physical_device_id(i))
                for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
                               tp_size))

        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

        # Initialize the engine after setting up environment.
        super().__init__(input_path, output_path, vllm_config, executor_class,
                         log_stats, dp_rank)

        # Counts busy-loop iterations since the last finish-sync all-reduce
        # so that we can synchronize finished with DP peers every
        # `_FINISH_SYNC_INTERVAL` steps.
        self.counter = 0

    def shutdown(self):
        """Shut down the engine and destroy the DP process group if set."""
        super().shutdown()
        # getattr: __init__ may have failed before `dp_group` was assigned.
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            local_unfinished_reqs = self.scheduler.has_unfinished_requests()

            if local_unfinished_reqs:
                # 2) Step the engine core.
                self._process_engine_step()

                # Check if we have now finished all requests.
                local_unfinished_reqs = (
                    self.scheduler.has_unfinished_requests())
            else:
                if self.scheduler.has_finished_requests():
                    # There are no unfinished requests, but there are some
                    # finished requests remaining to be removed from the
                    # batch state. This engine step won't perform a forward
                    # pass but will flush the finished requests to ensure
                    # up-to-date state is returned in the engine outputs.
                    self._process_engine_step()

                if not self.global_unfinished_reqs:
                    # All engines are idle.
                    continue

                # There must be unfinished requests in DP peers, run a
                # dummy forward pass.
                self.execute_dummy_batch()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.global_unfinished_reqs = self._has_global_unfinished_reqs(
                local_unfinished_reqs)

            if not self.global_unfinished_reqs:
                # Notify client that we are pausing the loop.
                self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS)

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank may still have unfinished requests.

        Between syncs this optimistically returns True; every
        `_FINISH_SYNC_INTERVAL` calls it performs the all-reduce via
        ``ParallelConfig.has_unfinished_dp``.
        """

        # Optimization - only perform finish-sync all-reduce every N steps.
        self.counter += 1
        if self.counter < self._FINISH_SYNC_INTERVAL:
            return True
        self.counter = 0

        return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                local_unfinished)