core.py 16.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import queue
4
import signal
5
6
import threading
import time
7
from concurrent.futures import Future
8
from inspect import isclass, signature
9
from multiprocessing.connection import Connection
10
from typing import Any, Optional
11

12
import msgspec
13
import psutil
14
15
16
import zmq
import zmq.asyncio

17
from vllm.config import VllmConfig
18
from vllm.logger import init_logger
19
from vllm.lora.request import LoRARequest
20
21
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
22
from vllm.utils import get_exception_traceback, zmq_socket_ctx
23
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
24
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
25
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
26
                            EngineCoreRequestType, UtilityOutput)
27
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
28
from vllm.v1.executor.abstract import Executor
29
from vllm.v1.outputs import ModelRunnerOutput
30
from vllm.v1.request import Request, RequestStatus
31
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
32
from vllm.v1.structured_output import StructuredOutputManager
33
34
35
36
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

37
# Timeout, in seconds, passed to the blocking ``batch_queue.get`` call in
# EngineCore.step_with_batch_queue.
POLLING_TIMEOUT_S = 2.5
38
39
40
41
42
43
44
45


class EngineCore:
    """Inner loop of vLLM's Engine."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        """Create the executor, KV caches, scheduler and batch queue.

        Args:
            vllm_config: Engine configuration. Its ``cache_config`` block
                counts are overwritten with the values returned by
                ``_initialize_kv_caches``.
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Stored on the instance and forwarded to the scheduler.
        """
        assert vllm_config.model_config.runner_type != "pooling"

        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Model executor; it is needed by the KV cache setup below.
        self.model_executor = executor_class(vllm_config)

        # Size the KV caches, then record the block counts in the cache
        # config before the scheduler reads it.
        gpu_blocks, cpu_blocks = self._initialize_kv_caches(vllm_config)
        vllm_config.cache_config.num_gpu_blocks = gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = cpu_blocks

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Scheduler.
        self.scheduler = Scheduler(
            scheduler_config=vllm_config.scheduler_config,
            model_config=vllm_config.model_config,
            cache_config=vllm_config.cache_config,
            lora_config=vllm_config.lora_config,
            speculative_config=vllm_config.speculative_config,
            log_stats=self.log_stats,
            structured_output_manager=self.structured_output_manager,
        )

        # Multimodal input cache server.
        self.mm_input_cache_server = MMInputCacheServer(
            vllm_config.model_config)

        # Queue of in-flight (future, scheduler output) batches. It lets
        # scheduling and execution overlap, which pipeline parallelism
        # requires to eliminate pipeline bubbles. It stays None unless the
        # executor supports more than one concurrent batch.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[tuple[
            Future[ModelRunnerOutput], SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

94
    def _initialize_kv_caches(self,
                              vllm_config: VllmConfig) -> tuple[int, int]:
        """Profile memory, size the KV cache and initialize the executor.

        Args:
            vllm_config: Engine configuration, passed to
                ``get_kv_cache_configs``.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``. The CPU block count is
            always 0 here.

        Raises:
            AssertionError: If the KV cache configs do not all report the
                same number of blocks.
        """
        # Use a monotonic, high-resolution clock so the reported duration
        # cannot be skewed by wall-clock adjustments during initialization.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        # Profiles the peak memory usage of the model to determine how much
        # memory can be allocated for kv cache.
        available_gpu_memory = self.model_executor.determine_available_memory()

        # Get the kv cache tensor size
        kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
                                                available_gpu_memory)
        num_gpu_blocks_set = {
            config.num_blocks
            for config in kv_cache_configs
        }
        assert len(num_gpu_blocks_set) == 1, (
            "num_gpu_blocks need to be the same across workers, "
            f"but they are different: {num_gpu_blocks_set}")
        num_gpu_blocks = num_gpu_blocks_set.pop()
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks

    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler.

        Multimodal inputs are first passed through the MM input cache
        server, and structured-output requests are registered with the
        structured output manager before the request is queued.
        """
        hashes = request.mm_hashes
        if hashes is not None:
            # The server-side cache mirrors the client cache: an input
            # whose hash is known is fetched from the cache, otherwise it
            # is added to it. Anything that carries a hash must therefore
            # have a cache entry here as well.
            assert request.mm_inputs is not None
            request.mm_inputs = self.mm_input_cache_server.get_and_update(
                request.mm_inputs, hashes)

        scheduled_request = Request.from_engine_core_request(request)

        if scheduled_request.use_structured_output:
            # Start grammar compilation asynchronously
            self.structured_output_manager.populate_cache(scheduled_request)

        self.scheduler.add_request(scheduled_request)

144
    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler.

        Args:
            request_ids: IDs of the requests to mark as
                ``RequestStatus.FINISHED_ABORTED``.
        """
        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        aborted = RequestStatus.FINISHED_ABORTED
        self.scheduler.finish_requests(request_ids, aborted)

153
    def step(self) -> EngineCoreOutputs:
        """Schedule, execute, and make output.

        Returns:
            The scheduler's outputs for the executed batch, or an
            ``EngineCoreOutputs`` with no outputs (stats only) when there
            was nothing to run.
        """
        scheduler = self.scheduler

        def stats_only_outputs():
            # Empty result that still carries fresh scheduler stats.
            return EngineCoreOutputs(
                outputs=[],
                scheduler_stats=scheduler.make_stats(),
            )

        if not scheduler.has_unfinished_requests():
            return stats_only_outputs()

        scheduled = scheduler.schedule()

        # This case may occur when the only unfinished requests are
        # structured output requests where the grammar has not finished
        # compiling yet, so there's nothing to run.
        if scheduled.total_num_scheduled_tokens == 0:
            return stats_only_outputs()

        model_output = self.model_executor.execute_model(scheduled)
        return scheduler.update_from_output(
            scheduled, model_output)  # type: ignore

    def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. If there are unscheduled requests and the batch queue is not
        full, try to schedule a new batch. When a batch with tokens is
        scheduled it is submitted to the executor, queued, and None is
        returned without looking at any model output.
        2. Otherwise (nothing was scheduled, or the scheduled batch has no
        tokens), block until the oldest queued batch is finished.
        3. Update the scheduler from that batch's output.
        """
        assert self.batch_queue is not None

        scheduled = None
        # Non-blocking attempt to schedule a new batch.
        can_schedule = (self.scheduler.get_num_unscheduled_requests() > 0
                        and not self.batch_queue.full())
        if can_schedule:
            scheduled = self.scheduler.schedule()
            if scheduled.total_num_scheduled_tokens > 0:
                pending = self.model_executor.execute_model(scheduled)
                self.batch_queue.put_nowait(
                    (pending, scheduled))  # type: ignore

        # Only wait on the queue when no new work was dispatched above.
        must_wait = (scheduled is None
                     or scheduled.total_num_scheduled_tokens == 0)
        if not must_wait:
            return None

        try:
            pending, scheduled = self.batch_queue.get(
                timeout=POLLING_TIMEOUT_S)
            # Blocking until the first result is available.
            finished = pending.result()
            self.batch_queue.task_done()
            return self.scheduler.update_from_output(scheduled, finished)
        except queue.Empty:
            # If the queue is empty (timeout at .get), return
            # an empty EngineCoreOutputs for logging.
            return EngineCoreOutputs(
                outputs=[], scheduler_stats=self.scheduler.make_stats())

226
227
228
    def shutdown(self):
        self.model_executor.shutdown()

229
    def profile(self, is_start: bool = True):
230
        self.model_executor.profile(is_start)
231

232
233
234
    def reset_prefix_cache(self):
        self.scheduler.reset_prefix_cache()

235
236
237
238
239
240
    def sleep(self, level: int = 1):
        self.model_executor.sleep(level)

    def wake_up(self):
        self.model_executor.wake_up()

241
242
243
    def execute_dummy_batch(self):
        self.model_executor.collective_rpc("execute_dummy_batch")

244
245
246
247
248
249
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward a LoRA add request to the model executor.

        Returns:
            Whatever the executor's ``add_lora`` returns.
        """
        outcome = self.model_executor.add_lora(lora_request)
        return outcome

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_executor.remove_lora(lora_id)

250
    def list_loras(self) -> set[int]:
251
252
253
254
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_executor.pin_lora(lora_id)
255

256
257
258
259
260
261
262
263

class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

    def __init__(
        self,
        input_path: str,
        output_path: str,
        ready_pipe: Connection,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        """Initialize the engine core and start its socket IO threads.

        Args:
            input_path: Passed to the input socket thread.
            output_path: Passed to the output socket thread.
            ready_pipe: Connection on which ``{"status": "READY"}`` is sent
                once setup is complete.
            vllm_config: Forwarded to ``EngineCore.__init__``.
            executor_class: Forwarded to ``EngineCore.__init__``.
            log_stats: Forwarded to ``EngineCore.__init__``.
        """
        super().__init__(vllm_config, executor_class, log_stats)

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        self.input_queue: queue.Queue[tuple[EngineCoreRequestType,
                                            Any]] = queue.Queue()
        self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()

        # Input thread is started first, then the output thread.
        io_workers = (
            (self.process_input_socket, input_path),
            (self.process_output_socket, output_path),
        )
        for worker, socket_path in io_workers:
            threading.Thread(target=worker,
                             args=(socket_path, ),
                             daemon=True).start()

        # Send Readiness signal to EngineClient.
        ready_pipe.send({"status": "READY"})
288
289
290
291
292

    @staticmethod
    def run_engine_core(*args, **kwargs):
        """Launch EngineCore busy loop in background process.

        All arguments are forwarded to ``EngineCoreProc``. SIGTERM and
        SIGINT end the loop; any other exception is logged and reported to
        the parent process with SIGUSR1. The engine core, if it was
        created, is always shut down on exit.
        """
        # Set once the first termination signal has been handled, so that
        # SystemExit is raised only once and this and the worker processes
        # can terminate without error.
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def on_terminate(signum, frame):
            nonlocal shutdown_requested
            if shutdown_requested:
                return
            shutdown_requested = True
            raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        for terminating_signal in (signal.SIGTERM, signal.SIGINT):
            signal.signal(terminating_signal, on_terminate)

        parent_process = psutil.Process().parent()
        engine_core = None
        try:
            engine_core = EngineCoreProc(*args, **kwargs)
            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore interrupted.")

        except Exception:
            error_text = get_exception_traceback()
            logger.error("EngineCore hit an exception: %s", error_text)
            parent_process.send_signal(signal.SIGUSR1)

        finally:
            if engine_core is not None:
                engine_core.shutdown()

329
330
331
    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

332
333
334
        step_fn = (self.step
                   if self.batch_queue is None else self.step_with_batch_queue)

335
336
        # Loop until process is sent a SIGINT or SIGTERM
        while True:
337
            # 1) Poll the input queue until there is work to do.
338
339
340
341
            while not self.scheduler.has_unfinished_requests():
                logger.debug("EngineCore busy loop waiting.")
                req = self.input_queue.get()
                self._handle_client_request(*req)
342

343
            # 2) Handle any new client requests.
344
345
            while not self.input_queue.empty():
                req = self.input_queue.get_nowait()
346
                self._handle_client_request(*req)
347
348

            # 3) Step the engine core.
349
            outputs = step_fn()
350

351
352
353
            # 4) Put EngineCoreOutputs into the output queue.
            if outputs is not None:
                self.output_queue.put_nowait(outputs)
354

355
356
357
    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Dispatch request from client.

        Args:
            request_type: Kind of request; ADD, ABORT and UTILITY are
                handled, any other value is ignored.
            request: Payload. For ADD it is passed to ``add_request``, for
                ABORT to ``abort_requests``, and for UTILITY it is a
                ``(call_id, method_name, args)`` triple naming a method of
                this object to invoke.

        A failing utility call is reported back to the client through the
        output queue instead of being raised. Interpreter-exit exceptions
        (``SystemExit``, ``KeyboardInterrupt``) are deliberately NOT
        intercepted: the termination signal handler raises ``SystemExit``
        only once, so swallowing it here would lose the shutdown request.
        """
        if request_type == EngineCoreRequestType.ADD:
            self.add_request(request)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                output.result = method(
                    *self._convert_msgspec_args(method, args))
            except Exception as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (f"Call to {method_name} method"
                                          f" failed: {str(e)}")
            # Sent for both success and failure so the caller waiting on
            # call_id always gets a response.
            self.output_queue.put_nowait(
                EngineCoreOutputs(utility_output=output))

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
         arg type, try converting to msgspec object."""
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
            msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation) else v
            for v, p in zip(args, arg_types))
390
391
392
393
394

    def process_input_socket(self, input_path: str):
        """Input socket IO thread.

        Receives two-frame ``(request type, request data)`` messages from a
        PULL socket at ``input_path``, decodes them, and pushes
        ``(request_type, request)`` tuples onto ``self.input_queue``.
        Loops forever.
        """
        # Msgpack decoders: a typed one for ADD requests, and an untyped
        # one for everything else.
        typed_decoder = MsgpackDecoder(EngineCoreRequest)
        untyped_decoder = MsgpackDecoder()

        with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
            while True:
                # (RequestType, RequestData)
                type_frame, data_frame = socket.recv_multipart(copy=False)
                request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                # Deserialize the request data.
                if request_type == EngineCoreRequestType.ADD:
                    decoder = typed_decoder
                else:
                    decoder = untyped_decoder
                payload = decoder.decode(data_frame.buffer)

                # Push to input queue for core busy loop.
                self.input_queue.put_nowait((request_type, payload))
412
413
414
415
416

    def process_output_socket(self, output_path: str):
        """Output socket IO thread.

        Takes items from ``self.output_queue``, msgpack-encodes them and
        sends each as a single-frame message on a PUSH socket at
        ``output_path``. Loops forever.
        """
        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # One send buffer, reused for every message.
        # NOTE(review): the same bytearray is re-encoded into on the next
        # iteration while sends use copy=False -- confirm ZMQ no longer
        # references the previous contents at that point.
        send_buffer = bytearray()

        with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
            while True:
                pending = self.output_queue.get()
                encoder.encode_into(pending, send_buffer)
                socket.send_multipart((send_buffer, ), copy=False)