core.py 17.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import queue
4
import signal
5
6
import threading
import time
7
from concurrent.futures import Future
8
from inspect import isclass, signature
9
from multiprocessing.connection import Connection
10
from typing import Any, Optional
11

12
import msgspec
13
import psutil
14
15
16
import zmq
import zmq.asyncio

17
from vllm.config import VllmConfig
18
from vllm.logger import init_logger
19
from vllm.lora.request import LoRARequest
20
21
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
22
23
from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
                        zmq_socket_ctx)
24
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
25
from vllm.v1.core.scheduler import Scheduler as V1Scheduler
26
from vllm.v1.core.scheduler import SchedulerOutput
27
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
28
                            EngineCoreRequestType, UtilityOutput)
29
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
30
from vllm.v1.executor.abstract import Executor
31
from vllm.v1.outputs import ModelRunnerOutput
32
from vllm.v1.request import Request, RequestStatus
33
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
34
from vllm.v1.structured_output import StructuredOutputManager
35
36
37
38
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# Polling timeout for queue waits (the `_S` suffix suggests seconds).
# NOTE(review): not referenced by the code in this module; confirm whether
# other modules import it before removing.
POLLING_TIMEOUT_S = 2.5


class EngineCore:
    """Inner loop of vLLM's Engine."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        """Build the executor, size the KV cache, and create the scheduler.

        Args:
            vllm_config: Engine configuration. Its
                ``cache_config.num_gpu_blocks`` / ``num_cpu_blocks`` are
                overwritten with the values determined by profiling.
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Stored on the engine and forwarded to the scheduler.
        """
        assert vllm_config.model_config.runner_type != "pooling"

        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
            vllm_config)

        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler. The scheduler class may be configured either as a
        # class object or as a fully-qualified name to resolve.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls)

        self.scheduler = Scheduler(
            scheduler_config=vllm_config.scheduler_config,
            model_config=vllm_config.model_config,
            cache_config=vllm_config.cache_config,
            lora_config=vllm_config.lora_config,
            speculative_config=vllm_config.speculative_config,
            log_stats=self.log_stats,
            structured_output_manager=self.structured_output_manager,
        )

        # Setup MM Input Mapper.
        self.mm_input_cache_server = MMInputCacheServer(
            vllm_config.model_config)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
                                                     SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

    def _initialize_kv_caches(self,
                              vllm_config: VllmConfig) -> tuple[int, int]:
        """Profile memory, size the KV cache, and initialize it on workers.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``. ``num_cpu_blocks`` is
            always 0 here.

        Raises:
            AssertionError: if the per-worker KV cache configs disagree on
                the number of GPU blocks.
        """
        # Monotonic clock: elapsed-time measurement must not be affected by
        # wall-clock adjustments.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        # Profiles the peak memory usage of the model to determine how much
        # memory can be allocated for kv cache.
        available_gpu_memory = self.model_executor.determine_available_memory()

        # Get the kv cache tensor size
        kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
                                                available_gpu_memory)
        num_gpu_blocks_set = set(config.num_blocks
                                 for config in kv_cache_configs)
        assert len(num_gpu_blocks_set) == 1, (
            f"num_gpu_blocks need to be the same across workers, "
            f"but they are different: {num_gpu_blocks_set}")
        num_gpu_blocks = num_gpu_blocks_set.pop()
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks

    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler.

        Multimodal inputs are first resolved through the MM input cache when
        the request carries hashes. Requests that use structured output start
        grammar compilation before being handed to the scheduler.
        """
        if request.mm_hashes is not None:
            # Here, if hash exists for a multimodal input, then it will be
            # fetched from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client cache, so
            # anything that has a hash must have a HIT cache entry here
            # as well.
            assert request.mm_inputs is not None
            request.mm_inputs = self.mm_input_cache_server.get_and_update(
                request.mm_inputs, request.mm_hashes)

        req = Request.from_engine_core_request(request)

        if req.use_structured_output:
            # Start grammar compilation asynchronously
            self.structured_output_manager.grammar_init(req)

        self.scheduler.add_request(req)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def step(self) -> EngineCoreOutputs:
        """Schedule, execute, and make output.

        Returns an ``EngineCoreOutputs`` with no outputs (only scheduler
        stats) when the scheduler holds no requests.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return EngineCoreOutputs(
                outputs=[],
                scheduler_stats=self.scheduler.make_stats(),
            )
        scheduler_output = self.scheduler.schedule()
        output = self.model_executor.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, output)  # type: ignore

        return engine_core_outputs

    def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if there are unscheduled requests
        and the batch queue is not full. If a new batch is scheduled, directly
        return an empty engine core output. In other words, we won't check
        and return model outputs before the batch queue is full.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the batch queue is finished.
        3. Update the scheduler from the output.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None
        # If there are unscheduled requests and the batch queue
        # is not full, schedule a new batch. Note that this is not blocking.
        if (self.scheduler.get_num_unscheduled_requests() > 0
                and not self.batch_queue.full()):
            scheduler_output = self.scheduler.schedule()
            if scheduler_output.total_num_scheduled_tokens > 0:
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        scheduled_batch = (scheduler_output is not None
                           and scheduler_output.total_num_scheduled_tokens > 0)

        # If no more requests can be scheduled and the batch queue is not
        # empty, block until the first batch in the batch queue is finished.
        if not scheduled_batch and not self.batch_queue.empty():
            future, scheduler_output = self.batch_queue.get_nowait()
            # Blocking until the first result is available.
            model_output = future.result()
            self.batch_queue.task_done()
            engine_core_outputs = self.scheduler.update_from_output(
                scheduler_output, model_output)

        return engine_core_outputs

    def shutdown(self):
        """Shut down the model executor."""
        self.model_executor.shutdown()

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop profiling on the executor."""
        self.model_executor.profile(is_start)

    def reset_prefix_cache(self):
        """Reset the scheduler's prefix cache."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Put the model executor to sleep at the given level."""
        self.model_executor.sleep(level)

    def wake_up(self):
        """Wake the model executor up."""
        self.model_executor.wake_up()

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` flag."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run ``execute_dummy_batch`` on the workers via collective RPC."""
        self.model_executor.collective_rpc("execute_dummy_batch")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate LoRA loading to the model executor."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate LoRA removal to the model executor."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the LoRA ids reported by the model executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate LoRA pinning to the model executor."""
        return self.model_executor.pin_lora(lora_id)


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

    def __init__(
        self,
        input_path: str,
        output_path: str,
        ready_pipe: Connection,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        """Initialize the engine, start the socket IO threads, signal READY.

        Args:
            input_path: ZMQ address the input thread PULLs requests from.
            output_path: ZMQ address the output thread PUSHes outputs to.
            ready_pipe: Pipe on which ``{"status": "READY"}`` is sent once
                initialization has finished.
            vllm_config: Forwarded to :class:`EngineCore`.
            executor_class: Forwarded to :class:`EngineCore`.
            log_stats: Forwarded to :class:`EngineCore`.
        """
        super().__init__(vllm_config, executor_class, log_stats)

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        self.input_queue: queue.Queue[tuple[EngineCoreRequestType,
                                            Any]] = queue.Queue()
        self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
        threading.Thread(target=self.process_input_socket,
                         args=(input_path, ),
                         daemon=True).start()
        threading.Thread(target=self.process_output_socket,
                         args=(output_path, ),
                         daemon=True).start()

        # Send Readiness signal to EngineClient.
        ready_pipe.send({"status": "READY"})

    @staticmethod
    def run_engine_core(*args, **kwargs):
        """Launch EngineCore busy loop in background process.

        ``args``/``kwargs`` are passed straight to :class:`EngineCoreProc`.
        On an unexpected exception the traceback is logged and SIGUSR1 is
        sent to the parent process; the engine is always shut down on exit.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        parent_process = psutil.Process().parent()
        engine_core = None
        try:
            engine_core = EngineCoreProc(*args, **kwargs)
            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore interrupted.")

        except Exception:
            traceback = get_exception_traceback()
            logger.error("EngineCore hit an exception: %s", traceback)
            parent_process.send_signal(signal.SIGUSR1)

        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        step_fn = (self.step
                   if self.batch_queue is None else self.step_with_batch_queue)

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            while not self.scheduler.has_requests():
                logger.debug("EngineCore busy loop waiting.")
                req = self.input_queue.get()
                self._handle_client_request(*req)

            # 2) Handle any new client requests.
            while not self.input_queue.empty():
                req = self.input_queue.get_nowait()
                self._handle_client_request(*req)

            # 3) Step the engine core.
            outputs = step_fn()

            # 4) Put EngineCoreOutputs into the output queue.
            if outputs is not None:
                self.output_queue.put_nowait(outputs)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Dispatch request from client.

        ``ADD`` requests go to :meth:`add_request` and ``ABORT`` requests to
        :meth:`abort_requests`. ``UTILITY`` requests are
        ``(call_id, method_name, args)`` tuples. They invoke the named method
        on this engine and always enqueue a ``UtilityOutput`` carrying either
        the result or a failure message.
        """
        if request_type == EngineCoreRequestType.ADD:
            self.add_request(request)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                output.result = method(
                    *self._convert_msgspec_args(method, args))
            except Exception as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (f"Call to {method_name} method"
                                          f" failed: {str(e)}")
            except BaseException as e:
                # SystemExit / KeyboardInterrupt (e.g. raised by the
                # SIGTERM/SIGINT handler installed in run_engine_core) must
                # not be swallowed. That handler only raises once, so
                # swallowing it here would leave the engine unstoppable by
                # signals. Report the failure, then let it propagate.
                output.failure_message = (f"Call to {method_name} method"
                                          f" failed: {str(e)}")
                raise
            finally:
                self.output_queue.put_nowait(
                    EngineCoreOutputs(utility_output=output))
        else:
            logger.error("Unrecognized request type: %s", request_type)

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
        arg type, try converting to msgspec object.

        An arg is converted with ``msgspec.convert`` only when its target
        parameter is annotated with a ``msgspec.Struct`` subclass and the arg
        is not already an instance of it. All other args pass through
        unchanged.

        Raises:
            TypeError: if more args are supplied than ``method`` declares
                parameters (args come from the client, so this is checked
                explicitly rather than with an ``assert``).
        """
        if not args:
            return args
        params = list(signature(method).parameters.values())
        if len(args) > len(params):
            raise TypeError(
                f"{getattr(method, '__name__', method)} accepts at most "
                f"{len(params)} arguments ({len(args)} given)")
        return tuple(
            msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation) else v
            for v, p in zip(args, params))

    def process_input_socket(self, input_path: str):
        """Input socket IO thread.

        Receives ``(request_type, data)`` multipart messages on a ZMQ PULL
        socket. ADD payloads are decoded as ``EngineCoreRequest`` and all
        others with a generic msgpack decoder. The results are pushed onto
        the input queue.
        """

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
            while True:
                # (RequestType, RequestData)
                type_frame, data_frame = socket.recv_multipart(copy=False)
                request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                # Deserialize the request data.
                decoder = add_request_decoder if (
                    request_type
                    == EngineCoreRequestType.ADD) else generic_decoder
                request = decoder.decode(data_frame.buffer)

                # Push to input queue for core busy loop.
                self.input_queue.put_nowait((request_type, request))

    def process_output_socket(self, output_path: str):
        """Output socket IO thread.

        Takes each ``EngineCoreOutputs`` from the output queue, encodes it
        into a reused buffer, and sends it on a ZMQ PUSH socket.
        """

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Reuse send buffer.
        buffer = bytearray()

        with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
            while True:
                outputs = self.output_queue.get()
                encoder.encode_into(outputs, buffer)
                socket.send_multipart((buffer, ), copy=False)