core.py 24.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
import os
3
import queue
4
import signal
5
import sys
6
7
import threading
import time
8
from concurrent.futures import Future
9
from inspect import isclass, signature
10
from logging import DEBUG
11
from typing import Any, Callable, Optional, TypeVar, Union
12

13
import msgspec
14
import psutil
15
16
17
import zmq
import zmq.asyncio

18
19
20
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
21
from vllm.logger import init_logger
22
from vllm.lora.request import LoRARequest
23
24
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
25
26
from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
                        zmq_socket_ctx)
27
28
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                         unify_kv_cache_configs)
29
from vllm.v1.core.sched.interface import SchedulerInterface
30
31
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
32
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
33
                            EngineCoreRequestType, UtilityOutput)
34
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
35
from vllm.v1.executor.abstract import Executor
36
from vllm.v1.outputs import ModelRunnerOutput
37
from vllm.v1.request import Request, RequestStatus
38
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
39
from vllm.v1.structured_output import StructuredOutputManager
40
41
42
43
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# NOTE(review): not referenced by the code in this file's visible portion;
# presumably a polling timeout in seconds used by other modules — confirm
# before changing or removing.
POLLING_TIMEOUT_S = 2.5

_R = TypeVar('_R')  # Return type for collective_rpc


class EngineCore:
    """Inner loop of vLLM's Engine."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        """Create the executor, KV caches, scheduler and helper state.

        Args:
            vllm_config: Engine configuration. ``cache_config.num_gpu_blocks``
                and ``cache_config.num_cpu_blocks`` are overwritten in place
                with the values computed during KV-cache initialization.
            executor_class: Executor type; instantiated with ``vllm_config``.
            log_stats: Stored on the engine and forwarded to the scheduler.
        """
        assert vllm_config.model_config.runner_type != "pooling"

        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
            vllm_config)
        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler. The configured class may be given either as a
        # fully-qualified name string or as the class object itself.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls)

        self.scheduler: SchedulerInterface = Scheduler(
            scheduler_config=vllm_config.scheduler_config,
            model_config=vllm_config.model_config,
            cache_config=vllm_config.cache_config,
            lora_config=vllm_config.lora_config,
            speculative_config=vllm_config.speculative_config,
            include_finished_set=vllm_config.parallel_config.data_parallel_size
            > 1,
            log_stats=self.log_stats,
            structured_output_manager=self.structured_output_manager,
        )

        # Setup MM Input Mapper.
        self.mm_input_cache_server = MMInputCacheServer(
            vllm_config.model_config)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
                                                     SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

    def _initialize_kv_caches(self,
                              vllm_config: VllmConfig) -> tuple[int, int]:
        """Size the KV cache from profiled memory and initialize workers.

        Args:
            vllm_config: Engine configuration passed to
                ``get_kv_cache_config``.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``. ``num_cpu_blocks`` is
            always 0 here.
        """
        # perf_counter is monotonic, so the reported duration cannot be
        # skewed by wall-clock adjustments during the (long) warmup.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        # Profiles the peak memory usage of the model to determine how much
        # memory can be allocated for kv cache.
        available_gpu_memory = self.model_executor.determine_available_memory()

        # One spec and one memory figure per worker.
        assert len(kv_cache_specs) == len(available_gpu_memory)
        # Get the kv cache tensor size
        kv_cache_configs = [
            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
                                available_gpu_memory_one_worker)
            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
            zip(kv_cache_specs, available_gpu_memory)
        ]

        # Since we use a shared centralized controller, we need the
        # `kv_cache_config` to be consistent across all workers to make sure
        # all the memory operators can be applied to all workers.
        unify_kv_cache_configs(kv_cache_configs)

        # All workers have the same kv_cache_config except layer names, so use
        # an arbitrary one to get the number of blocks.
        assert all(cfg.num_blocks == kv_cache_configs[0].num_blocks
                   for cfg in kv_cache_configs)
        num_gpu_blocks = kv_cache_configs[0].num_blocks
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks

    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler.

        Multimodal inputs are resolved through the MM input cache first and,
        for structured-output requests, grammar compilation is started
        before the request is handed to the scheduler.
        """
        if request.mm_hashes is not None:
            # Here, if hash exists for a multimodal input, then it will be
            # fetched from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client cache, so
            # anything that has a hash must have a HIT cache entry here
            # as well.
            assert request.mm_inputs is not None
            request.mm_inputs = self.mm_input_cache_server.get_and_update(
                request.mm_inputs, request.mm_hashes)

        req = Request.from_engine_core_request(request)
        if req.use_structured_output:
            # Start grammar compilation asynchronously
            self.structured_output_manager.grammar_init(req)

        self.scheduler.add_request(req)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def step(self) -> EngineCoreOutputs:
        """Schedule, execute, and make output.

        Returns an output with no request outputs (only scheduler stats)
        when the scheduler holds no requests.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return EngineCoreOutputs(
                outputs=[],
                scheduler_stats=self.scheduler.make_stats(),
            )
        scheduler_output = self.scheduler.schedule()
        output = self.model_executor.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, output)  # type: ignore

        return engine_core_outputs

    def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if there are unscheduled requests
        and the job queue is not full. If a new batch is scheduled, directly
        return an empty engine core output. In other words, we won't check
        and return model outputs before the batch queue is full.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None
        # If there are unscheduled requests and the job queue
        # is not full, schedule a new batch. Note that this is not blocking.
        if (self.scheduler.get_num_unscheduled_requests() > 0
                and not self.batch_queue.full()):
            scheduler_output = self.scheduler.schedule()
            if scheduler_output.total_num_scheduled_tokens > 0:
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        scheduled_batch = (scheduler_output is not None
                           and scheduler_output.total_num_scheduled_tokens > 0)

        # If no more requests can be scheduled and the job queue is not empty,
        # block until the first batch in the job queue is finished.
        if not scheduled_batch and not self.batch_queue.empty():
            future, scheduler_output = self.batch_queue.get_nowait()
            # Blocking until the first result is available.
            model_output = future.result()
            self.batch_queue.task_done()
            engine_core_outputs = self.scheduler.update_from_output(
                scheduler_output, model_output)

        return engine_core_outputs

    def shutdown(self):
        """Shut down the model executor."""
        self.model_executor.shutdown()

    def profile(self, is_start: bool = True):
        """Forward a profiler start (``True``) / stop request to the executor."""
        self.model_executor.profile(is_start)

    def reset_prefix_cache(self):
        """Reset the scheduler's prefix cache."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Forward a sleep request of the given level to the executor."""
        self.model_executor.sleep(level)

    def wake_up(self):
        """Forward a wake-up request to the executor."""
        self.model_executor.wake_up()

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` attribute."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Invoke ``execute_dummy_batch`` through the executor's collective RPC."""
        self.model_executor.collective_rpc("execute_dummy_batch")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward a LoRA adapter registration to the executor."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward a LoRA adapter removal to the executor."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Forward a LoRA pin request to the executor."""
        return self.model_executor.pin_lora(lora_id)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Pass-through to the executor's ``collective_rpc``."""
        return self.model_executor.collective_rpc(method, timeout, args,
                                                  kwargs)


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

    def __init__(
        self,
        input_path: str,
        output_path: str,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        engine_index: int = 0,
    ):
        """Initialize the engine core and start the socket IO threads.

        Args:
            input_path: ZMQ address the input (PULL) socket connects to.
            output_path: ZMQ address the output (PUSH) socket connects to.
            vllm_config: Engine configuration.
            executor_class: Executor type used by ``EngineCore``.
            log_stats: Whether the scheduler logs stats.
            engine_index: Stamped onto every ``EngineCoreOutputs`` sent.
        """
        super().__init__(vllm_config, executor_class, log_stats)

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        self.input_queue: queue.Queue[tuple[EngineCoreRequestType,
                                            Any]] = queue.Queue()
        self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
        threading.Thread(target=self.process_input_socket,
                         args=(input_path, ),
                         daemon=True).start()
        threading.Thread(target=self.process_output_socket,
                         args=(output_path, engine_index),
                         daemon=True).start()

        self.global_unfinished_reqs = False

        self.step_fn = (self.step if self.batch_queue is None else
                        self.step_with_batch_queue)

    @staticmethod
    def run_engine_core(*args,
                        dp_rank: int = 0,
                        local_dp_rank: int = 0,
                        ready_pipe,
                        **kwargs):
        """Launch EngineCore busy loop in background process.

        Builds a ``DPEngineCoreProc`` when ``data_parallel_size > 1``, else an
        ``EngineCoreProc``, reports readiness through ``ready_pipe`` and runs
        the busy loop until SIGTERM/SIGINT or an unhandled exception (which
        is reported to the parent process with SIGUSR1).
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        parent_process = psutil.Process().parent()
        engine_core: Optional[EngineCoreProc] = None
        try:
            parallel_config: ParallelConfig = kwargs[
                "vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1:
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
                engine_core = DPEngineCoreProc(*args, **kwargs)
            else:
                engine_core = EngineCoreProc(*args, **kwargs)

            # Send Readiness signal to EngineClient.
            ready_pipe.send({"status": "READY"})

            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore interrupted.")

        except Exception:
            traceback = get_exception_traceback()
            logger.error("EngineCore hit an exception: %s", traceback)
            parent_process.send_signal(signal.SIGUSR1)

        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()
            # 2) Step the engine core and return the outputs.
            self._process_engine_step()

    def _process_input_queue(self):
        """Exits when an engine step needs to be performed.

        Blocks on ``input_queue`` while there is neither global DP work nor
        any request held by the scheduler, then drains any further queued
        client requests without blocking.
        """

        waited = False
        while not self.global_unfinished_reqs and not (
                self.scheduler.has_requests()):
            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                logger.debug("EngineCore waiting for work.")
                waited = True
            req = self.input_queue.get()
            self._handle_client_request(*req)

        if waited:
            logger.debug(
                "EngineCore loop active - local unfinished: %s, finished: %s.",
                self.scheduler.has_unfinished_requests(),
                self.scheduler.has_finished_requests())

        # Handle any more client requests.
        while not self.input_queue.empty():
            req = self.input_queue.get_nowait()
            self._handle_client_request(*req)

    def _process_engine_step(self):
        """Called only when there are unfinished local requests."""

        # Step the engine core.
        outputs = self.step_fn()
        # Put EngineCoreOutputs into the output queue.
        if outputs is not None:
            self.output_queue.put_nowait(outputs)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Dispatch request from client.

        UTILITY requests are ``(call_id, method_name, args)`` tuples; the
        named method of this object is invoked and its result (or failure
        message) is queued as a ``UtilityOutput``.
        """

        if request_type == EngineCoreRequestType.ADD:
            self.add_request(request)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.START_DP:
            if not self.global_unfinished_reqs:
                logger.debug("EngineCore starting idle loop.")
                self.global_unfinished_reqs = True
        elif request_type == EngineCoreRequestType.UTILITY:
            call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                output.result = method(
                    *self._convert_msgspec_args(method, args))
            # Catch Exception, not BaseException: the SIGTERM/SIGINT handler
            # installed in run_engine_core raises SystemExit only once, so
            # swallowing it here would leave the engine ignoring shutdown.
            except Exception as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (f"Call to {method_name} method"
                                          f" failed: {str(e)}")
            self.output_queue.put_nowait(
                EngineCoreOutputs(utility_output=output))

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
         arg type, try converting to msgspec object."""
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
            msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation) else v
            for v, p in zip(args, arg_types))

    def process_input_socket(self, input_path: str):
        """Input socket IO thread.

        Receives ``(request type, payload)`` multipart messages, decodes the
        payload and pushes ``(request_type, request)`` onto ``input_queue``.
        """

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
            while True:
                # (RequestType, RequestData)
                type_frame, data_frame = socket.recv_multipart(copy=False)
                request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                # Deserialize the request data.
                decoder = add_request_decoder if (
                    request_type
                    == EngineCoreRequestType.ADD) else generic_decoder
                request = decoder.decode(data_frame.buffer)

                # Push to input queue for core busy loop.
                self.input_queue.put_nowait((request_type, request))

    def process_output_socket(self, output_path: str, engine_index: int):
        """Output socket IO thread.

        Takes ``EngineCoreOutputs`` from ``output_queue``, stamps them with
        ``engine_index``, msgpack-encodes them and pushes them to
        ``output_path``.
        """

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()

        # Send buffer, reused across sends once zmq is done with it.
        buffer = bytearray()
        # Completion tracker for the most recent zero-copy send of `buffer`.
        tracker: Optional[zmq.MessageTracker] = None

        with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
            while True:
                outputs = self.output_queue.get()
                outputs.engine_index = engine_index
                # With copy=False, zmq may still be reading the previous
                # buffer from its IO thread (messages above pyzmq's copy
                # threshold are not copied). Overwriting or resizing it now
                # would corrupt the in-flight message or raise BufferError,
                # so switch to a fresh buffer until that send completes.
                if tracker is not None and not tracker.done:
                    buffer = bytearray()
                encoder.encode_into(outputs, buffer)
                tracker = socket.send(buffer, copy=False, track=True)



# Pre-built output that DPEngineCoreProc pushes onto its output queue to tell
# the client that its busy loop is pausing. This single instance is reused;
# the output IO thread assigns ``engine_index`` on every object it sends,
# including this one.
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(
    engine_paused=True,
)


class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context.

    Each instance serves one data-parallel rank. In addition to the normal
    busy loop it runs dummy forward passes while peer ranks still have work,
    and periodically all-reduces whether any rank has unfinished requests.
    """

    def __init__(
        self,
        input_path: str,
        output_path: str,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        # stdout/stderr get a process-name/pid prefix before anything else is
        # initialized, so all engine output from this rank carries it.
        from multiprocessing import current_process
        proc_name = current_process().name
        proc_pid = os.getpid()
        for stream in (sys.stdout, sys.stderr):
            _add_prefix(stream, proc_name, proc_pid)

        par_cfg = vllm_config.parallel_config
        dp_size = par_cfg.data_parallel_size
        dp_rank = par_cfg.data_parallel_rank
        local_dp_rank = par_cfg.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        # On CUDA-like platforms, restrict this rank to its own contiguous
        # slice of tensor_parallel_size devices.
        from vllm.platforms import current_platform
        if current_platform.is_cuda_alike():
            from vllm.platforms.cuda import device_id_to_physical_device_id
            tp_size = par_cfg.tensor_parallel_size
            first_device = local_dp_rank * tp_size
            own_devices = range(first_device, first_device + tp_size)
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
                str(device_id_to_physical_device_id(dev))
                for dev in own_devices)

        self.dp_group = par_cfg.stateless_init_dp_group()

        # The engine itself is built only once the environment above is set.
        super().__init__(input_path, output_path, vllm_config, executor_class,
                         log_stats, dp_rank)

        # Number of loop iterations since the last cross-rank finish sync.
        self.counter = 0

    def shutdown(self):
        """Shut down the engine, then destroy the DP process group if any."""
        super().shutdown()
        dp_group = getattr(self, "dp_group", None)
        if dp_group:
            stateless_destroy_torch_distributed_process_group(dp_group)

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case.

        Runs until the process is sent SIGINT or SIGTERM.
        """
        while True:
            # Block on, then drain, client input until there is work to do.
            self._process_input_queue()

            has_local_work = self.scheduler.has_unfinished_requests()

            if not has_local_work:
                if self.scheduler.has_finished_requests():
                    # Only finished requests remain in the batch state. This
                    # step performs no forward pass; it flushes them so the
                    # engine outputs reflect up-to-date state.
                    self._process_engine_step()

                if not self.global_unfinished_reqs:
                    # Every engine is idle; go back to waiting for input.
                    continue

                # Some DP peer still has unfinished requests, so run a dummy
                # forward pass.
                self.execute_dummy_batch()
            else:
                self._process_engine_step()
                # This step may have finished the last local request.
                has_local_work = self.scheduler.has_unfinished_requests()

            # Refresh the cross-rank view of unfinished requests.
            self.global_unfinished_reqs = self._has_global_unfinished_reqs(
                has_local_work)

            if not self.global_unfinished_reqs:
                # Tell the client this engine's loop is pausing.
                self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS)

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Return whether any DP rank may still have unfinished requests.

        The finish-sync all-reduce only happens on every 16th call; all
        other calls conservatively report True.
        """
        self.counter += 1
        if self.counter == 16:
            self.counter = 0
            return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                    local_unfinished)
        return True