core.py 11.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import queue
4
import signal
5
6
import threading
import time
7
from multiprocessing.connection import Connection
8
from typing import Any, List, Tuple, Type
9

10
import psutil
11
12
13
import zmq
import zmq.asyncio

14
from vllm.config import VllmConfig
15
from vllm.logger import init_logger
16
from vllm.lora.request import LoRARequest
17
18
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
19
from vllm.utils import get_exception_traceback, zmq_socket_ctx
20
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
21
from vllm.v1.core.scheduler import Scheduler
22
23
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                            EngineCoreRequestType)
24
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
25
from vllm.v1.executor.abstract import Executor
26
from vllm.v1.request import Request, RequestStatus
27
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
28
29
30
31
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

32
# Timeout, in seconds, for each blocking read of the input queue while the
# busy loop is waiting for work.
POLLING_TIMEOUT_S = 2.5
33
34
35
36
37
38
39
40


class EngineCore:
    """Inner loop of vLLM's Engine."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
        log_stats: bool,
    ):
        """Create the executor, set up the KV caches, and build the scheduler.

        Args:
            vllm_config: Engine configuration. Its ``cache_config`` is
                updated in place with the block counts returned by
                ``_initialize_kv_caches``.
            executor_class: Executor type; instantiated with ``vllm_config``.
            log_stats: Stored on the instance and forwarded to the scheduler.
        """
        assert vllm_config.model_config.runner_type != "pooling"

        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # The executor must exist before the KV caches can be set up.
        self.model_executor = executor_class(vllm_config)

        # Write the block counts back to the cache config so that the
        # scheduler below is built from the updated values.
        gpu_blocks, cpu_blocks = self._initialize_kv_caches(vllm_config)
        cache_config = vllm_config.cache_config
        cache_config.num_gpu_blocks = gpu_blocks
        cache_config.num_cpu_blocks = cpu_blocks

        model_config = vllm_config.model_config
        self.scheduler = Scheduler(
            scheduler_config=vllm_config.scheduler_config,
            model_config=model_config,
            cache_config=cache_config,
            lora_config=vllm_config.lora_config,
            log_stats=self.log_stats,
        )

        self.mm_input_cache_server = MMInputCacheServer(model_config)
71

72
    def _initialize_kv_caches(self,
                              vllm_config: VllmConfig) -> Tuple[int, int]:
        """Set up the KV caches on the executor and report the block counts.

        Queries the executor for its KV cache specs and available memory,
        derives the KV cache configs from them, and hands those configs back
        to the executor for initialization and warmup.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``. The CPU block count is
            always 0 here.
        """
        began = time.time()
        executor = self.model_executor

        # All KV caches the model needs.
        specs = executor.get_kv_cache_specs()

        # Profile peak memory usage to learn how much memory is left over
        # for the KV cache.
        free_gpu_memory = executor.determine_available_memory()

        # Translate specs + free memory into the KV cache configs.
        kv_cache_configs = get_kv_cache_configs(vllm_config, specs,
                                                free_gpu_memory)

        # Every config has to report the same number of blocks.
        block_counts = {cfg.num_blocks for cfg in kv_cache_configs}
        assert len(block_counts) == 1, (
            f"num_gpu_blocks need to be the same across workers, "
            f"but they are different: {block_counts}")
        gpu_blocks = block_counts.pop()
        cpu_blocks = 0

        # Create the KV caches and warm up execution.
        executor.initialize(kv_cache_configs)

        took = time.time() - began
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), took)
        return gpu_blocks, cpu_blocks

    def add_request(self, request: EngineCoreRequest):
        """Convert an incoming request and hand it to the scheduler.

        When the request carries multimodal hashes, its ``mm_inputs`` are
        first replaced by the result of the multimodal input cache lookup.
        """
        hashes = request.mm_hashes
        if hashes is not None:
            # An input whose hash is already known is fetched from the
            # cache; otherwise it is added to the cache. This cache mirrors
            # the client-side cache, so anything that has a hash must have
            # a HIT entry here as well.
            inputs = request.mm_inputs
            assert inputs is not None
            request.mm_inputs = self.mm_input_cache_server.get_and_update(
                inputs, hashes)

        self.scheduler.add_request(Request.from_engine_core_request(request))

    def abort_requests(self, request_ids: List[str]):
        """Finish the given requests in the scheduler as aborted.

        Args:
            request_ids: IDs of the requests to abort.
        """
        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        aborted = RequestStatus.FINISHED_ABORTED
        self.scheduler.finish_requests(request_ids, aborted)

128
    def step(self) -> EngineCoreOutputs:
129
130
131
        """Schedule, execute, and make output."""

        if not self.scheduler.has_unfinished_requests():
132
133
            return EngineCoreOutputs(
                outputs=[], scheduler_stats=self.scheduler.make_stats())
134
135
136
137
138
139
140

        scheduler_output = self.scheduler.schedule()
        output = self.model_executor.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, output)
        return engine_core_outputs

141
142
143
    def shutdown(self):
        self.model_executor.shutdown()

144
    def profile(self, is_start: bool = True):
145
        self.model_executor.profile(is_start)
146

147
148
149
    def reset_prefix_cache(self):
        self.scheduler.reset_prefix_cache()

150
151
152
    def add_lora(self, lora_request: LoRARequest) -> None:
        self.model_executor.add_lora(lora_request)

153
154
155
156
157
158
159
160

class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

    def __init__(
        self,
        input_path: str,
        output_path: str,
        ready_pipe: Connection,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
        log_stats: bool,
    ):
        """Initialize the engine core, start the IO threads, signal readiness.

        Args:
            input_path: Address handed to the input socket thread.
            output_path: Address handed to the output socket thread.
            ready_pipe: Connection on which ``{"status": "READY"}`` is sent
                once setup is done.
            vllm_config: Forwarded to ``EngineCore.__init__``.
            executor_class: Forwarded to ``EngineCore.__init__``.
            log_stats: Forwarded to ``EngineCore.__init__``.
        """
        super().__init__(vllm_config, executor_class, log_stats)

        # Queues decouple socket IO from the busy loop: the IO threads move
        # data between the sockets and these queues, while the busy loop
        # only ever touches the queues. This lets ZMQ socket IO (which
        # releases the GIL) and some (de)serialization overlap with the
        # model forward pass.
        self.input_queue: queue.Queue[Tuple[EngineCoreRequestType,
                                            Any]] = queue.Queue()
        self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()

        # One daemon thread per direction; the input thread starts first.
        io_workers = (
            (self.process_input_socket, input_path),
            (self.process_output_socket, output_path),
        )
        for worker, path in io_workers:
            threading.Thread(target=worker, args=(path, ),
                             daemon=True).start()

        # Tell the EngineClient that we are ready.
        ready_pipe.send({"status": "READY"})
185
186
187
188
189

    @staticmethod
    def run_engine_core(*args, **kwargs):
        """Launch EngineCore busy loop in background process.

        Installs SIGTERM/SIGINT handlers, builds an ``EngineCoreProc`` from
        ``*args``/``**kwargs`` and runs its busy loop. ``SystemExit`` raised
        by the signal handler ends the loop quietly. Any other exception is
        logged and the parent process is sent SIGUSR1. The engine core, if
        it was created, is shut down on every exit path.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        # May be None: psutil returns None when no parent PID is known.
        parent_process = psutil.Process().parent()
        engine_core = None
        try:
            engine_core = EngineCoreProc(*args, **kwargs)
            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore interrupted.")

        except Exception:
            traceback = get_exception_traceback()
            logger.error("EngineCore hit an exception: %s", traceback)
            # Without a known parent there is nobody to notify; skipping the
            # signal avoids an AttributeError from calling it on None.
            if parent_process is not None:
                parent_process.send_signal(signal.SIGUSR1)
            else:
                logger.error(
                    "EngineCore has no parent process to notify of the "
                    "failure.")

        finally:
            if engine_core is not None:
                engine_core.shutdown()

226
227
228
    def run_busy_loop(self):
        """Core busy loop of the EngineCore.

        Repeatedly pulls client requests off the input queue, steps the
        engine, and queues the step's outputs for the output thread. The
        loop has no exit condition of its own; it ends only when an
        exception propagates out of it.
        """

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) With nothing scheduled, block on the input queue until a
            #    request arrives.
            if not self.scheduler.has_unfinished_requests():
                while True:
                    try:
                        pending = self.input_queue.get(
                            timeout=POLLING_TIMEOUT_S)
                        self._handle_client_request(*pending)
                    except queue.Empty:
                        logger.debug("EngineCore busy loop waiting.")
                        # Keep polling unless stats logging is on, in which
                        # case fall through so step() can log_stats.
                        if not self.log_stats:
                            continue
                    break

            # 2) Drain whatever else has been queued in the meantime.
            while not self.input_queue.empty():
                queued = self.input_queue.get_nowait()
                self._handle_client_request(*queued)

            # 3) Step the engine core.
            step_outputs = self.step()

            # 4) Put EngineCoreOutputs into the output queue.
            self.output_queue.put_nowait(step_outputs)

257
258
259
    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Dispatch request from client.

        Every request type is routed through the corresponding method on
        this object, so an override of any of them is honoured.

        Args:
            request_type: Selects the handler to invoke.
            request: Payload for that handler. It is not used for
                RESET_PREFIX_CACHE.
        """

        if request_type == EngineCoreRequestType.ADD:
            self.add_request(request)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
            self.reset_prefix_cache()
        elif request_type == EngineCoreRequestType.PROFILE:
            self.profile(request)
        elif request_type == EngineCoreRequestType.ADD_LORA:
            self.add_lora(request)
        else:
            # Make a dropped request visible instead of ignoring it silently.
            logger.warning("Ignoring unrecognized request type: %s",
                           request_type)
271
272
273
274
275

    def process_input_socket(self, input_path: str):
        """Input socket IO thread.

        Receives two-frame ``(type, data)`` messages from a ZMQ PULL socket
        at ``input_path``, decodes the data frame with the msgpack decoder
        matching the request type, and puts ``(request_type, request)`` on
        the input queue. Loops forever.
        """

        # One msgpack decoder per payload kind.
        add_decoder = MsgpackDecoder(EngineCoreRequest)
        lora_decoder = MsgpackDecoder(LoRARequest)
        fallback_decoder = MsgpackDecoder()

        def pick_decoder(kind):
            # ADD and ADD_LORA carry typed payloads; everything else is
            # decoded generically.
            if kind == EngineCoreRequestType.ADD:
                return add_decoder
            if kind == EngineCoreRequestType.ADD_LORA:
                return lora_decoder
            return fallback_decoder

        with zmq_socket_ctx(input_path, zmq.constants.PULL) as pull_sock:
            while True:
                # (RequestType, RequestData)
                type_frame, data_frame = pull_sock.recv_multipart(copy=False)
                kind = EngineCoreRequestType(bytes(type_frame.buffer))

                # Deserialize the request data.
                payload = pick_decoder(kind).decode(data_frame.buffer)

                # Hand over to the core busy loop.
                self.input_queue.put_nowait((kind, payload))
299
300
301
302
303

    def process_output_socket(self, output_path: str):
        """Output socket IO thread.

        Blocks on the output queue, msgpack-encodes each item into a single
        reused buffer, and sends it as a one-frame message on a ZMQ PUSH
        socket at ``output_path``. Loops forever.
        """

        # Msgpack serialization encoding.
        packer = MsgpackEncoder()
        # A single bytearray is reused as the send buffer for every message.
        # NOTE(review): this buffer is passed to a copy=False send and then
        # re-encoded into on the next iteration; confirm this is safe with
        # pyzmq's zero-copy semantics.
        send_buf = bytearray()

        with zmq_socket_ctx(output_path, zmq.constants.PUSH) as push_sock:
            while True:
                item = self.output_queue.get()
                packer.encode_into(item, send_buf)
                push_sock.send_multipart((send_buf, ), copy=False)