core.py 11.2 KB
Newer Older
1
import pickle
2
import queue
3
import signal
4
5
import threading
import time
6
from multiprocessing.connection import Connection
7
from typing import List, Tuple, Type
8

9
import psutil
10
11
12
13
14
15
import zmq
import zmq.asyncio
from msgspec import msgpack

from vllm.config import CacheConfig, VllmConfig
from vllm.logger import init_logger
16
17
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
18
from vllm.utils import get_exception_traceback, zmq_socket_ctx
19
20
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
21
                            EngineCoreProfile, EngineCoreRequest,
22
                            EngineCoreRequestType, EngineCoreRequestUnion)
23
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
24
from vllm.v1.executor.abstract import Executor
25
from vllm.v1.request import Request, RequestStatus
26
from vllm.v1.serial_utils import PickleEncoder
27
28
29
30
31
32
from vllm.version import __version__ as VLLM_VERSION

# Module-level logger used by EngineCore and EngineCoreProc below.
logger = init_logger(__name__)

# How long the idle busy loop blocks on the input queue before waking up
# (to log stats) and polling again. The seconds value is derived from the
# millisecond value so the two can never drift apart.
POLLING_TIMEOUT_MS = 5_000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1_000

# Minimum number of seconds between two periodic stats log lines.
LOGGING_TIME_S = 5


class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor and the scheduler and advances them one
    schedule/execute/update step at a time. It does no IPC itself; see
    ``EngineCoreProc`` for the socket-driven wrapper.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
        log_stats: bool = False,
    ):
        # Pooling models are rejected by this engine core.
        assert vllm_config.model_config.runner_type != "pooling"
        self.log_stats = log_stats

        logger.info("Initializing an LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        # The executor must exist before KV-cache sizing, which profiles it.
        self.model_executor = executor_class(vllm_config)

        # Profile + allocate the KV cache, then publish the block counts on
        # the CacheConfig so the scheduler below is built with them.
        cache_config = vllm_config.cache_config
        (cache_config.num_gpu_blocks,
         cache_config.num_cpu_blocks) = self._initialize_kv_caches(cache_config)

        self.scheduler = Scheduler(vllm_config.scheduler_config,
                                   cache_config,
                                   vllm_config.lora_config)

        self._last_logging_time = time.time()

        self.mm_input_mapper_server = MMInputMapperServer(
            vllm_config.model_config)

    def _initialize_kv_caches(self,
                              cache_config: CacheConfig) -> Tuple[int, int]:
        """Size, allocate and warm up the KV cache.

        The GPU block count comes from executor profiling unless
        ``cache_config.num_gpu_blocks_override`` is set.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``; the CPU count is always 0.
        """
        t_start = time.time()

        profiled_gpu_blocks, _ = (
            self.model_executor.determine_num_available_blocks())

        override = cache_config.num_gpu_blocks_override
        if override is None:
            num_gpu_blocks = profiled_gpu_blocks
        else:
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", profiled_gpu_blocks, override)
            num_gpu_blocks = override

        self.model_executor.initialize(num_gpu_blocks)

        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"),
                    time.time() - t_start)
        return num_gpu_blocks, 0

    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler.

        Multimodal inputs carrying hashes are first resolved through the
        server-side MM input mapper cache.
        """
        if request.mm_hashes is not None:
            # Here, if hash exists for an image, then it will be fetched
            # from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client side of the
            # MM mapper, so anything that has a hash must have a HIT cache
            # entry here as well.
            assert request.mm_inputs is not None
            request.mm_inputs = self.mm_input_mapper_server.process_inputs(
                request.mm_inputs, request.mm_hashes)

        self.scheduler.add_request(Request.from_engine_core_request(request))

    def abort_requests(self, request_ids: List[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def step(self) -> List[EngineCoreOutput]:
        """Schedule, execute, and make output.

        Returns an empty list, without touching the model, when the
        scheduler has nothing left to run.
        """
        if self.scheduler.has_unfinished_requests():
            scheduler_output = self.scheduler.schedule()
            model_output = self.model_executor.execute_model(scheduler_output)
            return self.scheduler.update_from_output(scheduler_output,
                                                     model_output)
        return []

    def shutdown(self):
        """Shut down the model executor."""
        self.model_executor.shutdown()

    def profile(self, is_start: bool = True):
        """Start (``is_start=True``) or stop profiling on the executor."""
        self.model_executor.profile(is_start)

class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process.

    Two daemon IO threads move data between ZMQ sockets and in-process
    queues, while ``run_busy_loop`` drives ``EngineCore.step`` from those
    queues.
    """

    def __init__(
        self,
        input_path: str,
        output_path: str,
        ready_pipe: Connection,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
        log_stats: bool = False,
    ):
        super().__init__(vllm_config, executor_class, log_stats)

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
        self.output_queue: queue.Queue[List[EngineCoreOutput]] = queue.Queue()
        threading.Thread(target=self.process_input_socket,
                         args=(input_path, ),
                         daemon=True).start()
        threading.Thread(target=self.process_output_socket,
                         args=(output_path, ),
                         daemon=True).start()

        # Send Readiness signal to EngineClient.
        ready_pipe.send({"status": "READY"})

    @staticmethod
    def run_engine_core(*args, **kwargs):
        """Launch EngineCore busy loop in background process.

        All arguments are forwarded to ``EngineCoreProc``. SIGTERM/SIGINT end
        the loop quietly; any other exception is logged and the parent
        process is sent SIGUSR1. The engine is shut down on every exit path.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        # psutil returns None when the parent process cannot be found.
        parent_process = psutil.Process().parent()
        engine_core = None
        try:
            engine_core = EngineCoreProc(*args, **kwargs)
            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore interrupted.")

        except Exception:
            traceback = get_exception_traceback()
            logger.error("EngineCore hit an exception: %s", traceback)
            # Tell the parent process the engine core failed.
            if parent_process is not None:
                parent_process.send_signal(signal.SIGUSR1)

        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            if not self.scheduler.has_unfinished_requests():
                while True:
                    try:
                        req = self.input_queue.get(timeout=POLLING_TIMEOUT_S)
                        self._handle_client_request(req)
                        break
                    except queue.Empty:
                        self._log_stats()
                        logger.debug("EngineCore busy loop waiting.")

            # 2) Handle any new client requests (Abort or Add).
            # This thread is the only consumer, so empty() -> get_nowait()
            # cannot race with another reader.
            while not self.input_queue.empty():
                req = self.input_queue.get_nowait()
                self._handle_client_request(req)

            # 3) Step the engine core.
            outputs = self.step()

            # 4) Put EngineCoreOutputs into the output queue.
            self.output_queue.put_nowait(outputs)

            self._log_stats()

    def _log_stats(self):
        """Log basic stats every LOGGING_TIME_S"""

        if not self.log_stats:
            return

        now = time.time()

        if now - self._last_logging_time > LOGGING_TIME_S:
            logger.info(
                "RUNNING: %s | WAITING: %s",
                len(self.scheduler.running),
                len(self.scheduler.waiting),
            )

            self._last_logging_time = now

    def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
        """Handle a request from the Client.

        ``request`` is an ``EngineCoreRequest`` (add), an
        ``EngineCoreProfile`` (start/stop profiling), or a list of request
        ids to abort.
        """

        if isinstance(request, EngineCoreRequest):
            self.add_request(request)
        elif isinstance(request, EngineCoreProfile):
            self.model_executor.profile(request.is_start)
        else:
            # TODO: make an EngineCoreAbort wrapper
            assert isinstance(request, list)
            self.abort_requests(request)

    def process_input_socket(self, input_path: str):
        """Input socket IO thread.

        Receives ``(request_type, request_data)`` multipart messages on
        ``input_path``, deserializes the payload by type and pushes the
        result onto ``self.input_queue`` for the busy loop.
        """

        # Payload decoders for ADD / ABORT (PickleEncoder, presumably
        # pickle-based; PROFILE payloads use pickle.loads directly below).
        # NOTE(review): pickle executes arbitrary code when loading, so
        # input_path must only be reachable by the trusted engine client.
        decoder_add_req = PickleEncoder()
        decoder_abort_req = PickleEncoder()

        with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
            while True:
                # (RequestType, RequestData)
                type_frame, data_frame = socket.recv_multipart(copy=False)
                request_type = type_frame.buffer
                request_data = data_frame.buffer

                # Deserialize the request data.
                if request_type == EngineCoreRequestType.ADD.value:
                    request = decoder_add_req.decode(request_data)
                elif request_type == EngineCoreRequestType.ABORT.value:
                    request = decoder_abort_req.decode(request_data)
                elif request_type == EngineCoreRequestType.PROFILE.value:
                    request = pickle.loads(request_data)
                else:
                    # NOTE(review): raising here ends this IO thread; the
                    # busy loop then receives no further requests.
                    raise ValueError(f"Unknown RequestType: {request_type}")

                # Push to input queue for core busy loop.
                self.input_queue.put_nowait(request)

    def process_output_socket(self, output_path: str):
        """Output socket IO thread.

        Takes each list of ``EngineCoreOutput`` from ``self.output_queue``,
        wraps it in ``EngineCoreOutputs``, msgpack-encodes it and pushes it
        on ``output_path``.
        """
        from collections import deque

        # Msgpack serialization encoding.
        encoder = msgpack.Encoder()

        # Send buffers are reused to avoid a fresh allocation per step. With
        # a zero-copy send (copy=False) zmq may still be reading a buffer
        # after send_multipart() returns, so overwriting or resizing it right
        # away could corrupt the in-flight message. Each buffer is therefore
        # parked with its MessageTracker and only reused once zmq is done.
        free_buffers: List[bytearray] = []
        pending = deque()  # of (zmq.MessageTracker, bytearray), oldest first

        with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
            while True:
                engine_core_outputs = self.output_queue.get()
                outputs = EngineCoreOutputs(outputs=engine_core_outputs)

                # Reclaim buffers whose sends zmq has finished with.
                while pending and pending[0][0].done:
                    free_buffers.append(pending.popleft()[1])

                buffer = free_buffers.pop() if free_buffers else bytearray()
                encoder.encode_into(outputs, buffer)
                tracker = socket.send_multipart((buffer, ),
                                                copy=False,
                                                track=True)
                if tracker is None or tracker.done:
                    # Already sent (e.g. small messages are copied by zmq).
                    free_buffers.append(buffer)
                else:
                    pending.append((tracker, buffer))