# core.py
1
import pickle
2
import queue
3
import signal
4
5
import threading
import time
6
from dataclasses import dataclass
7
from multiprocessing.process import BaseProcess
8
from typing import List, Tuple, Type, Union
9
10
11
12
13
14

import zmq
import zmq.asyncio
from msgspec import msgpack

from vllm.config import CacheConfig, VllmConfig
15
from vllm.executor.multiproc_worker_utils import get_mp_context
16
17
18
19
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
20
21
                            EngineCoreProfile, EngineCoreRequest,
                            EngineCoreRequestType)
22
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
23
from vllm.v1.executor.abstract import Executor
24
from vllm.v1.request import Request, RequestStatus
25
from vllm.v1.serial_utils import PickleEncoder
26
from vllm.v1.utils import make_zmq_socket
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from vllm.version import __version__ as VLLM_VERSION

# Module-level logger obtained from vLLM's ``init_logger`` helper.
logger = init_logger(__name__)

# Poll timeout for zmq ``socket.poll`` (milliseconds) and the equivalent
# timeout for ``queue.Queue.get`` (seconds) used while waiting for work.
POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
# Minimum interval, in SECONDS, between periodic RUNNING/WAITING stats logs.
# The value 5000 here was the millisecond polling constant (~83 minutes when
# read as seconds); 5 s matches the polling interval.
LOGGING_TIME_S = 5


class EngineCore:
    """Inner loop of vLLM's Engine.

    Wires a model executor to a scheduler. The constructor sizes the KV cache
    and builds the scheduler; each ``step()`` call schedules one batch, runs
    it on the executor and lets the scheduler turn the result into outputs.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
        usage_context: UsageContext,
    ):
        """Create the executor, size the KV cache and create the scheduler.

        Args:
            vllm_config: Engine configuration. ``cache_config.num_gpu_blocks``
                and ``cache_config.num_cpu_blocks`` are overwritten in place
                with the values chosen by ``_initialize_kv_caches``.
            executor_class: Executor type, instantiated with ``vllm_config``.
            usage_context: Part of the constructor interface; not read here.
        """
        # Pooling runners are rejected by this engine core.
        assert vllm_config.model_config.runner_type != "pooling"

        logger.info("Initializing an LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        # Setup Model.
        self.model_executor = executor_class(vllm_config)

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
            vllm_config.cache_config)
        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

        # Setup scheduler.
        self.scheduler = Scheduler(vllm_config.scheduler_config,
                                   vllm_config.cache_config,
                                   vllm_config.lora_config)

        # Wall-clock time of the last periodic stats log (used to rate-limit
        # stats logging).
        self._last_logging_time = time.time()

        # Server-side multimodal input mapper; its cache is mirrored with the
        # client-side mapper (see add_request).
        self.mm_input_mapper_server = MMInputMapperServer()

    def _initialize_kv_caches(self,
                              cache_config: CacheConfig) -> Tuple[int, int]:
        """Profile, create the KV cache and warm up the model.

        Args:
            cache_config: Read for ``num_gpu_blocks_override``; when it is not
                None it replaces the profiled GPU block count.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``. ``num_cpu_blocks`` is always
            0; the second value reported by the executor is discarded.
        """
        # perf_counter is monotonic, so the reported duration cannot be
        # skewed by wall-clock adjustments during a long profile/warmup.
        start = time.perf_counter()

        num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks(
        )

        if cache_config.num_gpu_blocks_override is not None:
            num_gpu_blocks_override = cache_config.num_gpu_blocks_override
            logger.info(
                "Overriding num_gpu_blocks=%d with "
                "num_gpu_blocks_override=%d", num_gpu_blocks,
                num_gpu_blocks_override)
            num_gpu_blocks = num_gpu_blocks_override

        num_cpu_blocks = 0
        self.model_executor.initialize(num_gpu_blocks)

        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks

    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler.

        Multimodal inputs are first resolved through the server-side mapper
        cache when the request carries ``mm_hashes``.
        """
        if request.mm_hashes is not None:
            # Here, if hash exists for an image, then it will be fetched
            # from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client side of the
            # MM mapper, so anything that has a hash must have a HIT cache
            # entry here as well.
            request.mm_inputs = self.mm_input_mapper_server.process_inputs(
                request.mm_inputs, request.mm_hashes)

        req = Request.from_engine_core_request(request)
        self.scheduler.add_request(req)

    def abort_requests(self, request_ids: List[str]):
        """Abort requests from the scheduler.

        Args:
            request_ids: IDs of the requests to finish as
                ``RequestStatus.FINISHED_ABORTED``.
        """
        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def step(self) -> List[EngineCoreOutput]:
        """Schedule, execute, and make output.

        Returns:
            The scheduler's outputs for this step, or ``[]`` when there are
            no unfinished requests (nothing is scheduled or executed then).
        """
        if not self.scheduler.has_unfinished_requests():
            return []

        scheduler_output = self.scheduler.schedule()
        output = self.model_executor.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, output)
        return engine_core_outputs

    def shutdown(self):
        """Shut down the model executor."""
        self.model_executor.shutdown()

    def profile(self, is_start=True):
        """Forward a profiler toggle to the model executor.

        Args:
            is_start: Passed through to ``model_executor.profile``;
                presumably True starts and False stops profiling.
        """
        self.model_executor.profile(is_start)


@dataclass
class EngineCoreProcHandle:
    """Handle to a launched EngineCoreProc background process.

    Returned by ``EngineCoreProc.make_engine_core_process`` together with the
    ZMQ endpoints the process was started with.
    """
    # The started background process.
    proc: BaseProcess
    # ZMQ endpoint used for the startup readiness handshake.
    ready_path: str
    # ZMQ endpoint the engine core reads requests from.
    input_path: str
    # ZMQ endpoint the engine core writes outputs to.
    output_path: str


141
142
143
144
145
146
147
148
class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

    READY_STR = "READY"

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
        usage_context: UsageContext,
        input_path: str,
        output_path: str,
        ready_path: str,
    ):
        """Build the EngineCore, start the IO threads and signal readiness.

        Args:
            vllm_config: Forwarded to ``EngineCore``.
            executor_class: Forwarded to ``EngineCore``.
            usage_context: Forwarded to ``EngineCore``.
            input_path: ZMQ endpoint read (PULL) by the input IO thread.
            output_path: ZMQ endpoint written (PUSH) by the output IO thread.
            ready_path: ZMQ endpoint on which ``READY_STR`` is sent (PUSH)
                once construction has finished.
        """
        super().__init__(vllm_config, executor_class, usage_context)

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        self.input_queue = queue.Queue()
        self.output_queue = queue.Queue()
        threading.Thread(target=self.process_input_socket,
                         args=(input_path, ),
                         daemon=True).start()
        threading.Thread(target=self.process_output_socket,
                         args=(output_path, ),
                         daemon=True).start()

        # Send Readiness signal to EngineClient.
        with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket:
            ready_socket.send_string(EngineCoreProc.READY_STR)

    @staticmethod
    def wait_for_startup(
        proc: BaseProcess,
        ready_path: str,
    ) -> None:
        """Wait until the EngineCore is ready.

        Polls the PULL socket at ``ready_path`` every ``POLLING_TIMEOUT_MS``
        until a message arrives, checking in between that ``proc`` is alive.

        Raises:
            RuntimeError: if ``proc`` exits before signalling readiness, or
                the first message received is not ``READY_STR``.
        """
        # Create the context before ``try`` so the ``finally`` clause never
        # references an unbound name when construction itself fails.
        sync_ctx = zmq.Context()  # type: ignore[attr-defined]
        try:
            socket = sync_ctx.socket(zmq.constants.PULL)
            socket.connect(ready_path)

            # Wait for EngineCore to send EngineCoreProc.READY_STR.
            while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
                logger.debug("Waiting for EngineCoreProc to startup.")

                if not proc.is_alive():
                    raise RuntimeError("EngineCoreProc failed to start.")

            message = socket.recv_string()
            # Explicit check instead of ``assert`` (stripped under -O).
            if message != EngineCoreProc.READY_STR:
                raise RuntimeError(
                    "EngineCoreProc sent unexpected startup message: "
                    f"{message!r}")

        except BaseException as e:
            logger.exception(e)
            raise

        finally:
            # linger=0: discard anything unsent instead of blocking on close.
            sync_ctx.destroy(linger=0)

    @staticmethod
    def make_engine_core_process(
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
        usage_context: UsageContext,
        input_path: str,
        output_path: str,
        ready_path: str,
    ) -> EngineCoreProcHandle:
        """Start ``run_engine_core`` in a new process and wait for readiness.

        Returns:
            A handle holding the started process and its ZMQ endpoints.

        Raises:
            RuntimeError: propagated from ``wait_for_startup``.
        """
        context = get_mp_context()

        process_kwargs = {
            "input_path": input_path,
            "output_path": output_path,
            "ready_path": ready_path,
            "vllm_config": vllm_config,
            "executor_class": executor_class,
            "usage_context": usage_context,
        }
        # Run EngineCore busy loop in background process.
        proc = context.Process(target=EngineCoreProc.run_engine_core,
                               kwargs=process_kwargs)
        proc.start()

        # Wait for startup
        EngineCoreProc.wait_for_startup(proc, ready_path)
        return EngineCoreProcHandle(proc=proc,
                                    ready_path=ready_path,
                                    input_path=input_path,
                                    output_path=output_path)

    @staticmethod
    def run_engine_core(*args, **kwargs):
        """Launch EngineCore busy loop in background process.

        ``*args``/``**kwargs`` are forwarded to ``EngineCoreProc``. The engine
        core is shut down on exit, whether by signal or by error.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        engine_core = None
        try:
            engine_core = EngineCoreProc(*args, **kwargs)
            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore interrupted.")

        except BaseException as e:
            logger.exception(e)
            raise

        finally:
            if engine_core is not None:
                engine_core.shutdown()
                engine_core = None

    def run_busy_loop(self):
        """Core busy loop of the EngineCore.

        Each iteration blocks until there is work, drains pending client
        requests, runs one ``step()`` and queues its outputs for the output
        IO thread. Never returns normally.
        """

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            if not self.scheduler.has_unfinished_requests():
                while True:
                    try:
                        req = self.input_queue.get(timeout=POLLING_TIMEOUT_S)
                    except queue.Empty:
                        self._log_stats()
                        logger.debug("EngineCore busy loop waiting.")
                        continue
                    self._handle_client_request(req)
                    break

            # 2) Handle any new client requests (Abort or Add).
            while not self.input_queue.empty():
                req = self.input_queue.get_nowait()
                self._handle_client_request(req)

            # 3) Step the engine core.
            outputs = self.step()

            # 4) Put EngineCoreOutputs into the output queue.
            self.output_queue.put_nowait(outputs)

            self._log_stats()

    def _log_stats(self):
        """Log basic stats every LOGGING_TIME_S"""

        now = time.time()

        if now - self._last_logging_time > LOGGING_TIME_S:
            logger.info(
                "RUNNING: %s | WAITING: %s",
                len(self.scheduler.running),
                len(self.scheduler.waiting),
            )

            self._last_logging_time = now

    def _handle_client_request(
        self, request: Union[EngineCoreRequest, EngineCoreProfile,
                             List[str]]) -> None:
        """Dispatch one decoded client message.

        ``EngineCoreRequest`` -> add, ``EngineCoreProfile`` -> profiler
        toggle, ``list`` of request IDs -> abort.
        """

        if isinstance(request, EngineCoreRequest):
            self.add_request(request)
        elif isinstance(request, EngineCoreProfile):
            # Go through EngineCore.profile (-> model_executor.profile)
            # rather than reaching into an executor-specific ``.worker``.
            self.profile(request.is_start)
        else:
            # TODO: make an EngineCoreAbort wrapper
            assert isinstance(request, list)
            self.abort_requests(request)

    def process_input_socket(self, input_path: str):
        """Input socket IO thread.

        Receives ``(request_type, request_data)`` multipart messages on the
        PULL socket at ``input_path``, decodes them and pushes the results to
        ``input_queue`` for ``run_busy_loop``.

        Raises:
            ValueError: on an unknown request type; this ends the IO thread.
        """

        # Request decoders (PickleEncoder, presumably pickle-based -- not
        # msgpack).
        # NOTE(review): unpickling socket payloads is only safe while
        # ``input_path`` is reachable solely by the trusted engine client.
        decoder_add_req = PickleEncoder()
        decoder_abort_req = PickleEncoder()

        with make_zmq_socket(input_path, zmq.constants.PULL) as socket:
            while True:
                # (RequestType, RequestData)
                type_frame, data_frame = socket.recv_multipart(copy=False)
                request_type = type_frame.buffer
                request_data = data_frame.buffer

                # Deserialize the request data.
                if request_type == EngineCoreRequestType.ADD.value:
                    request = decoder_add_req.decode(request_data)
                elif request_type == EngineCoreRequestType.ABORT.value:
                    request = decoder_abort_req.decode(request_data)
                elif request_type == EngineCoreRequestType.PROFILE.value:
                    request = pickle.loads(request_data)
                else:
                    raise ValueError(f"Unknown RequestType: {request_type}")

                # Push to input queue for core busy loop.
                self.input_queue.put_nowait(request)

    def process_output_socket(self, output_path: str):
        """Output socket IO thread.

        Takes step outputs from ``output_queue``, wraps them in
        ``EngineCoreOutputs``, msgpack-encodes them and sends them on the
        PUSH socket at ``output_path``.
        """

        # Msgpack serialization encoding.
        encoder = msgpack.Encoder()
        # Reuse send buffer.
        # NOTE(review): the buffer is rewritten by encode_into on the next
        # iteration while it was sent with copy=False. pyzmq copies small
        # messages anyway; confirm large messages cannot still reference it.
        buffer = bytearray()

        with make_zmq_socket(output_path, zmq.constants.PUSH) as socket:
            while True:
                engine_core_outputs = self.output_queue.get()
                outputs = EngineCoreOutputs(outputs=engine_core_outputs)
                encoder.encode_into(outputs, buffer)
                socket.send_multipart((buffer, ), copy=False)