core.py 15 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import queue
4
import signal
5
6
import threading
import time
7
from concurrent.futures import Future
8
from multiprocessing.connection import Connection
9
from typing import Any, List, Optional, Tuple, Type
10

11
import psutil
12
13
14
import zmq
import zmq.asyncio

15
from vllm.config import VllmConfig
16
from vllm.logger import init_logger
17
from vllm.lora.request import LoRARequest
18
19
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
20
from vllm.utils import get_exception_traceback, zmq_socket_ctx
21
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
22
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
23
24
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                            EngineCoreRequestType)
25
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
26
from vllm.v1.executor.abstract import Executor
27
from vllm.v1.outputs import ModelRunnerOutput
28
from vllm.v1.request import Request, RequestStatus
29
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
30
31
32
33
from vllm.version import __version__ as VLLM_VERSION

# Module-level logger, created through vLLM's `init_logger` helper and
# shared by every class in this file.
logger = init_logger(__name__)

34
# Timeout, in seconds, passed to the blocking `queue.get(...)` calls in
# this file (the batch queue in `step_with_batch_queue` and the input
# queue in `run_busy_loop`), so those loops wake up periodically instead
# of blocking forever.
POLLING_TIMEOUT_S = 2.5
35
36
37
38
39
40
41
42


class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the scheduler and the multimodal input cache,
    and exposes `step()` / `step_with_batch_queue()` to run one iteration
    of schedule -> execute -> update.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
        log_stats: bool,
    ):
        """Build the executor, size the KV cache and create the scheduler.

        Args:
            vllm_config: Engine configuration. Its `cache_config` is
                updated in place with the block counts computed here.
            executor_class: Executor type, instantiated with `vllm_config`.
            log_stats: Forwarded to the scheduler and kept on the instance.
        """
        assert vllm_config.model_config.runner_type != "pooling"

        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
            vllm_config)
        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

        # Setup scheduler.
        self.scheduler = Scheduler(
            scheduler_config=vllm_config.scheduler_config,
            model_config=vllm_config.model_config,
            cache_config=vllm_config.cache_config,
            lora_config=vllm_config.lora_config,
            log_stats=self.log_stats,
        )

        # Setup MM Input Mapper.
        self.mm_input_cache_server = MMInputCacheServer(
            vllm_config.model_config)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[Tuple[Future[ModelRunnerOutput],
                                                     SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

    def _initialize_kv_caches(self,
                              vllm_config: VllmConfig) -> Tuple[int, int]:
        """Profile memory, derive KV cache configs and init the executor.

        Returns:
            A `(num_gpu_blocks, num_cpu_blocks)` tuple. `num_cpu_blocks`
            is always 0 in this implementation.

        Raises:
            AssertionError: If the KV cache configs do not all report the
                same number of blocks.
        """
        # Monotonic clock: the value is only used to measure a duration,
        # so it should not be affected by wall-clock adjustments.
        start = time.monotonic()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        # Profiles the peak memory usage of the model to determine how much
        # memory can be allocated for kv cache.
        available_gpu_memory = self.model_executor.determine_available_memory()

        # Get the kv cache tensor size
        kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
                                                available_gpu_memory)
        num_gpu_blocks_set = {config.num_blocks for config in kv_cache_configs}
        assert len(num_gpu_blocks_set) == 1, (
            f"num_gpu_blocks need to be the same across workers, "
            f"but they are different: {num_gpu_blocks_set}")
        num_gpu_blocks = num_gpu_blocks_set.pop()
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize(kv_cache_configs)

        elapsed = time.monotonic() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks

    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler."""

        if request.mm_hashes is not None:
            # Here, if hash exists for a multimodal input, then it will be
            # fetched from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client cache, so
            # anything that has a hash must have a HIT cache entry here
            # as well.
            assert request.mm_inputs is not None
            request.mm_inputs = self.mm_input_cache_server.get_and_update(
                request.mm_inputs, request.mm_hashes)

        req = Request.from_engine_core_request(request)

        self.scheduler.add_request(req)

    def abort_requests(self, request_ids: List[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def step(self) -> EngineCoreOutputs:
        """Schedule, execute, and make output.

        When the scheduler has no unfinished requests, nothing is executed
        and an `EngineCoreOutputs` with no outputs (only scheduler stats)
        is returned.
        """

        if not self.scheduler.has_unfinished_requests():
            return EngineCoreOutputs(
                outputs=[], scheduler_stats=self.scheduler.make_stats())

        scheduler_output = self.scheduler.schedule()
        output = self.model_executor.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, output)  # type: ignore
        return engine_core_outputs

    def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if there are unscheduled requests
        and the job queue is not full. If a new batch is scheduled, return
        None right away. In other words, we won't check and return model
        outputs before the batch queue is full.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.

        If waiting on the job queue times out after POLLING_TIMEOUT_S, an
        EngineCoreOutputs with no outputs (only scheduler stats) is returned.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None
        # If there are unscheduled requests and the job queue
        # is not full, schedule a new batch. Note that this is not blocking.
        if (self.scheduler.get_num_unscheduled_requests() > 0
                and not self.batch_queue.full()):
            scheduler_output = self.scheduler.schedule()
            if scheduler_output.total_num_scheduled_tokens > 0:
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        # If all requests are scheduled or the job queue is full,
        # block until the first batch in the job queue is finished.
        if (scheduler_output is None
                or scheduler_output.total_num_scheduled_tokens == 0):
            try:
                future, scheduler_output = self.batch_queue.get(
                    timeout=POLLING_TIMEOUT_S)
                # Blocking until the first result is available.
                model_output = future.result()
                self.batch_queue.task_done()
                engine_core_outputs = self.scheduler.update_from_output(
                    scheduler_output, model_output)
            except queue.Empty:
                # If the queue is empty (timeout at .get), return
                # an empty EngineCoreOutputs for logging.
                engine_core_outputs = EngineCoreOutputs(
                    outputs=[], scheduler_stats=self.scheduler.make_stats())

        return engine_core_outputs

    def shutdown(self):
        """Shut down the model executor."""
        self.model_executor.shutdown()

    def profile(self, is_start: bool = True):
        """Forward a profile request (`is_start`) to the model executor."""
        self.model_executor.profile(is_start)

    def reset_prefix_cache(self):
        """Ask the scheduler to reset its prefix cache."""
        self.scheduler.reset_prefix_cache()

    def add_lora(self, lora_request: LoRARequest) -> None:
        """Forward a LoRA request to the model executor."""
        self.model_executor.add_lora(lora_request)

216
217
218
219
220
221
222
223

class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process.

    Two daemon threads move messages between ZMQ sockets and in-process
    queues; `run_busy_loop` consumes the input queue, steps the engine and
    fills the output queue.
    """

    def __init__(
        self,
        input_path: str,
        output_path: str,
        ready_pipe: Connection,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
        log_stats: bool,
    ):
        """Initialize the core, start the IO threads and signal readiness.

        Args:
            input_path: Address handed to the input (PULL) socket thread.
            output_path: Address handed to the output (PUSH) socket thread.
            ready_pipe: Connection on which `{"status": "READY"}` is sent
                once initialization has finished.
            vllm_config: Passed through to `EngineCore.__init__`.
            executor_class: Passed through to `EngineCore.__init__`.
            log_stats: Passed through to `EngineCore.__init__`.
        """
        super().__init__(vllm_config, executor_class, log_stats)

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        self.input_queue: queue.Queue[Tuple[EngineCoreRequestType,
                                            Any]] = queue.Queue()
        self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
        threading.Thread(target=self.process_input_socket,
                         args=(input_path, ),
                         daemon=True).start()
        threading.Thread(target=self.process_output_socket,
                         args=(output_path, ),
                         daemon=True).start()

        # Send Readiness signal to EngineClient.
        ready_pipe.send({"status": "READY"})

    @staticmethod
    def run_engine_core(*args, **kwargs):
        """Launch EngineCore busy loop in background process.

        All positional and keyword arguments are forwarded to the
        `EngineCoreProc` constructor. SIGTERM/SIGINT end the loop quietly;
        any other exception is logged and reported to the parent process
        with SIGUSR1. The engine core is shut down in every case in which
        it was successfully constructed.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        parent_process = psutil.Process().parent()
        engine_core = None
        try:
            engine_core = EngineCoreProc(*args, **kwargs)
            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore interrupted.")

        except Exception:
            traceback = get_exception_traceback()
            logger.error("EngineCore hit an exception: %s", traceback)
            # psutil returns None when no parent process is known; skip
            # the notification then instead of raising AttributeError
            # from inside this handler.
            if parent_process is not None:
                parent_process.send_signal(signal.SIGUSR1)

        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        step_fn = (self.step
                   if self.batch_queue is None else self.step_with_batch_queue)

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            if not self.scheduler.has_unfinished_requests():
                while True:
                    try:
                        req = self.input_queue.get(timeout=POLLING_TIMEOUT_S)
                        self._handle_client_request(*req)
                        break
                    except queue.Empty:
                        logger.debug("EngineCore busy loop waiting.")
                        # Break out the loop so we can log_stats in step().
                        if self.log_stats:
                            break

            # 2) Handle any new client requests.
            while not self.input_queue.empty():
                req = self.input_queue.get_nowait()
                self._handle_client_request(*req)

            # 3) Step the engine core.
            outputs = step_fn()

            # 4) Put EngineCoreOutputs into the output queue.
            if outputs is not None:
                self.output_queue.put_nowait(outputs)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Dispatch request from client.

        Every request type is routed through this class's own methods so
        that all of them share a single entry point per operation. Request
        types not listed below are ignored.
        """

        if request_type == EngineCoreRequestType.ADD:
            self.add_request(request)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
            self.reset_prefix_cache()
        elif request_type == EngineCoreRequestType.PROFILE:
            self.profile(request)
        elif request_type == EngineCoreRequestType.ADD_LORA:
            self.add_lora(request)

    def process_input_socket(self, input_path: str):
        """Input socket IO thread.

        Receives `(type, data)` multipart messages forever, decodes the
        data frame with the decoder matching the request type, and pushes
        `(request_type, request)` onto `self.input_queue`.
        """

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        add_lora_decoder = MsgpackDecoder(LoRARequest)
        generic_decoder = MsgpackDecoder()

        with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
            while True:
                # (RequestType, RequestData)
                type_frame, data_frame = socket.recv_multipart(copy=False)
                request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                # Deserialize the request data.
                if request_type == EngineCoreRequestType.ADD:
                    decoder = add_request_decoder
                elif request_type == EngineCoreRequestType.ADD_LORA:
                    decoder = add_lora_decoder
                else:
                    decoder = generic_decoder

                request = decoder.decode(data_frame.buffer)

                # Push to input queue for core busy loop.
                self.input_queue.put_nowait((request_type, request))

    def process_output_socket(self, output_path: str):
        """Output socket IO thread.

        Takes `EngineCoreOutputs` from `self.output_queue` forever, encodes
        each one and sends it as a single-frame multipart message.
        """

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Reuse send buffer.
        # NOTE(review): the same bytearray is re-encoded on every iteration
        # while being sent with copy=False -- confirm the socket is done
        # with it before the next encode_into.
        buffer = bytearray()

        with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
            while True:
                outputs = self.output_queue.get()
                encoder.encode_into(outputs, buffer)
                socket.send_multipart((buffer, ), copy=False)