core.py 16.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import queue
4
import signal
5
6
import threading
import time
7
from concurrent.futures import Future
8
from multiprocessing.connection import Connection
9
from typing import Any, List, Optional, Tuple, Type
10

11
import psutil
12
13
14
import zmq
import zmq.asyncio

15
from vllm.config import VllmConfig
16
from vllm.logger import init_logger
17
from vllm.lora.request import LoRARequest
18
19
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
20
from vllm.utils import get_exception_traceback, zmq_socket_ctx
21
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
22
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
23
24
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                            EngineCoreRequestType)
25
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
26
from vllm.v1.executor.abstract import Executor
27
from vllm.v1.outputs import ModelRunnerOutput
28
from vllm.v1.request import Request, RequestStatus
29
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
30
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
31
32
33
34
from vllm.version import __version__ as VLLM_VERSION

# Module-level logger, created through vLLM's ``init_logger`` helper.
logger = init_logger(__name__)

35
# Timeout, in seconds, passed to the blocking ``queue.Queue.get`` calls in
# this module (batch queue and input queue) so that waiting loops wake up
# periodically instead of blocking indefinitely.
POLLING_TIMEOUT_S = 2.5
36
37
38
39
40
41
42
43


class EngineCore:
    """Inner loop of vLLM's Engine.

    Owns the model executor, the scheduler and the multimodal input cache
    server, and exposes the operations that drive them: adding and aborting
    requests and running one schedule -> execute -> update step.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
        log_stats: bool,
    ):
        """Build the executor, KV caches, scheduler and auxiliary state.

        Args:
            vllm_config: Engine configuration. The block counts in its
                ``cache_config`` are overwritten with the values returned by
                ``_initialize_kv_caches``.
            executor_class: Executor type; instantiated with ``vllm_config``.
            log_stats: Stored on the instance and forwarded to the scheduler.
        """
        assert vllm_config.model_config.runner_type != "pooling"

        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
            vllm_config)
        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

        # Setup scheduler.
        self.scheduler = Scheduler(
            scheduler_config=vllm_config.scheduler_config,
            model_config=vllm_config.model_config,
            cache_config=vllm_config.cache_config,
            lora_config=vllm_config.lora_config,
            speculative_config=vllm_config.speculative_config,
            log_stats=self.log_stats,
        )

        # Setup MM Input Mapper.
        self.mm_input_cache_server = MMInputCacheServer(
            vllm_config.model_config)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[Tuple[Future[ModelRunnerOutput],
                                                     SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

        # Setup speculative decode.
        # TODO: find a better way to check if we are using ngram.
        self.use_spec_decode = False
        if self.scheduler.speculative_config:
            assert self.scheduler.speculative_config.ngram_prompt_lookup_min \
                    , "Only ngram spec decode is supported in V1."
            self.proposer = NgramProposer()
            self.use_spec_decode = True

    def _initialize_kv_caches(self,
                              vllm_config: VllmConfig) -> Tuple[int, int]:
        """Size and create the KV caches, then warm up the executor.

        Args:
            vllm_config: Engine configuration passed through to
                ``get_kv_cache_configs``.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``. The CPU block count is
            always 0 here.

        Raises:
            AssertionError: If the per-config block counts are not all equal.
        """
        # Monotonic clock: the elapsed time must not be affected by
        # wall-clock adjustments that happen during initialization.
        start = time.monotonic()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        # Profiles the peak memory usage of the model to determine how much
        # memory can be allocated for kv cache.
        available_gpu_memory = self.model_executor.determine_available_memory()

        # Get the kv cache tensor size
        kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
                                                available_gpu_memory)
        num_gpu_blocks_set = {
            config.num_blocks
            for config in kv_cache_configs
        }
        assert len(num_gpu_blocks_set) == 1, (
            f"num_gpu_blocks need to be the same across workers, "
            f"but they are different: {num_gpu_blocks_set}")
        num_gpu_blocks = num_gpu_blocks_set.pop()
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize(kv_cache_configs)

        elapsed = time.monotonic() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks

    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler.

        When the request carries multimodal hashes, its ``mm_inputs`` are
        first replaced by the result of the input cache server's
        ``get_and_update``.
        """

        if request.mm_hashes is not None:
            # Here, if hash exists for a multimodal input, then it will be
            # fetched from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client cache, so
            # anything that has a hash must have a HIT cache entry here
            # as well.
            assert request.mm_inputs is not None
            request.mm_inputs = self.mm_input_cache_server.get_and_update(
                request.mm_inputs, request.mm_hashes)

        req = Request.from_engine_core_request(request)

        self.scheduler.add_request(req)

    def abort_requests(self, request_ids: List[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def step(self) -> EngineCoreOutputs:
        """Schedule, execute, and make output.

        Returns:
            The scheduler's outputs for this step. If there are no
            unfinished requests, an ``EngineCoreOutputs`` with no outputs
            and only the current scheduler stats.
        """

        if not self.scheduler.has_unfinished_requests():
            return EngineCoreOutputs(
                outputs=[], scheduler_stats=self.scheduler.make_stats())

        if self.use_spec_decode:
            self.propose_tokens()

        scheduler_output = self.scheduler.schedule()
        output = self.model_executor.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, output)  # type: ignore
        return engine_core_outputs

    def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if there are unscheduled requests
        and the job queue is not full. If a new batch is scheduled, directly
        return an empty engine core output. In other words, we won't check
        and return model outputs before the batch queue is full.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None
        # If there are unscheduled requests and the job queue
        # is not full, schedule a new batch. Note that this is not blocking.
        if (self.scheduler.get_num_unscheduled_requests() > 0
                and not self.batch_queue.full()):
            scheduler_output = self.scheduler.schedule()
            if scheduler_output.total_num_scheduled_tokens > 0:
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        # If all requests are scheduled or the job queue is full,
        # block until the first batch in the job queue is finished.
        if (scheduler_output is None
                or scheduler_output.total_num_scheduled_tokens == 0):
            # Only the queue poll is guarded: a queue.Empty escaping from
            # the future or from the scheduler update must not be mistaken
            # for a poll timeout.
            try:
                future, scheduler_output = self.batch_queue.get(
                    timeout=POLLING_TIMEOUT_S)
            except queue.Empty:
                # If the queue is empty (timeout at .get), return
                # an empty EngineCoreOutputs for logging.
                engine_core_outputs = EngineCoreOutputs(
                    outputs=[], scheduler_stats=self.scheduler.make_stats())
            else:
                # Blocking until the first result is available.
                model_output = future.result()
                self.batch_queue.task_done()
                engine_core_outputs = self.scheduler.update_from_output(
                    scheduler_output, model_output)

        return engine_core_outputs

    def shutdown(self):
        """Shut down the model executor."""
        self.model_executor.shutdown()

    def profile(self, is_start: bool = True):
        """Forward a profiling request (``is_start``) to the executor."""
        self.model_executor.profile(is_start)

    def propose_tokens(self):
        """Attach proposed speculative tokens to eligible running requests.

        Requests still in chunked prefill, and requests that already carry
        speculative tokens, are skipped.
        """
        spec_config = self.scheduler.speculative_config
        assert spec_config is not None
        for req in self.scheduler.running:
            # Ignore requests that are doing chunked prefill.
            if req.num_computed_tokens < req.num_tokens - 1:
                continue
            # Ignore requests that already have spec tokens.
            if req.spec_token_ids:
                continue
            spec_tokens = self.proposer.propose(
                req.all_token_ids,
                spec_config.ngram_prompt_lookup_min,
                spec_config.num_speculative_tokens,
            )
            if spec_tokens:
                req.append_spec_token_ids(spec_tokens)

    def reset_prefix_cache(self):
        """Ask the scheduler to reset its prefix cache."""
        self.scheduler.reset_prefix_cache()

    def add_lora(self, lora_request: LoRARequest) -> None:
        """Forward a LoRA request to the model executor."""
        self.model_executor.add_lora(lora_request)

247
248
249
250
251
252
253
254

class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process.

    Two daemon threads move messages between ZMQ sockets and in-process
    queues; ``run_busy_loop`` consumes the input queue, steps the engine and
    fills the output queue.
    """

    def __init__(
        self,
        input_path: str,
        output_path: str,
        ready_pipe: Connection,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
        log_stats: bool,
    ):
        """Initialize the core, start the IO threads and signal readiness.

        Args:
            input_path: Address handed to the input (PULL) socket thread.
            output_path: Address handed to the output (PUSH) socket thread.
            ready_pipe: Connection on which ``{"status": "READY"}`` is sent
                once initialization has finished.
            vllm_config: Passed to ``EngineCore.__init__``.
            executor_class: Passed to ``EngineCore.__init__``.
            log_stats: Passed to ``EngineCore.__init__``.
        """
        super().__init__(vllm_config, executor_class, log_stats)

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        self.input_queue: queue.Queue[Tuple[EngineCoreRequestType,
                                            Any]] = queue.Queue()
        self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
        threading.Thread(target=self.process_input_socket,
                         args=(input_path, ),
                         daemon=True).start()
        threading.Thread(target=self.process_output_socket,
                         args=(output_path, ),
                         daemon=True).start()

        # Send Readiness signal to EngineClient.
        ready_pipe.send({"status": "READY"})

    @staticmethod
    def run_engine_core(*args, **kwargs):
        """Launch EngineCore busy loop in background process.

        All arguments are forwarded to ``EngineCoreProc``. SIGTERM and SIGINT
        are turned into a single ``SystemExit``. Any other exception is
        logged and reported to the parent process with SIGUSR1. The engine
        core, if it was created, is always shut down on exit.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        parent_process = psutil.Process().parent()
        engine_core = None
        try:
            engine_core = EngineCoreProc(*args, **kwargs)
            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore interrupted.")

        except Exception:
            traceback = get_exception_traceback()
            logger.error("EngineCore hit an exception: %s", traceback)
            # psutil returns None when no parent process is known; do not
            # raise a secondary AttributeError from inside the handler.
            if parent_process is not None:
                parent_process.send_signal(signal.SIGUSR1)
            else:
                logger.warning(
                    "EngineCore has no parent process to notify.")

        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        step_fn = (self.step
                   if self.batch_queue is None else self.step_with_batch_queue)

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            if not self.scheduler.has_unfinished_requests():
                while True:
                    # Only the poll itself is guarded, so a queue.Empty
                    # escaping from request handling is not swallowed.
                    try:
                        req = self.input_queue.get(timeout=POLLING_TIMEOUT_S)
                    except queue.Empty:
                        logger.debug("EngineCore busy loop waiting.")
                        # Break out the loop so we can log_stats in step().
                        if self.log_stats:
                            break
                    else:
                        self._handle_client_request(*req)
                        break

            # 2) Handle any new client requests.
            while not self.input_queue.empty():
                req = self.input_queue.get_nowait()
                self._handle_client_request(*req)

            # 3) Step the engine core.
            outputs = step_fn()

            # 4) Put EngineCoreOutputs into the output queue.
            if outputs is not None:
                self.output_queue.put_nowait(outputs)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Dispatch request from client.

        Every request type is routed through the corresponding method of
        this class. Request types not listed below are ignored.
        """

        if request_type == EngineCoreRequestType.ADD:
            self.add_request(request)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
            self.reset_prefix_cache()
        elif request_type == EngineCoreRequestType.PROFILE:
            self.profile(request)
        elif request_type == EngineCoreRequestType.ADD_LORA:
            self.add_lora(request)

    def process_input_socket(self, input_path: str):
        """Input socket IO thread.

        Receives ``(type, data)`` multipart messages, decodes the data frame
        with the decoder matching the request type and enqueues
        ``(request_type, request)`` on ``input_queue``. Loops forever.
        """

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        add_lora_decoder = MsgpackDecoder(LoRARequest)
        generic_decoder = MsgpackDecoder()

        with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
            while True:
                # (RequestType, RequestData)
                type_frame, data_frame = socket.recv_multipart(copy=False)
                request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                # Deserialize the request data.
                if request_type == EngineCoreRequestType.ADD:
                    decoder = add_request_decoder
                elif request_type == EngineCoreRequestType.ADD_LORA:
                    decoder = add_lora_decoder
                else:
                    decoder = generic_decoder
                request = decoder.decode(data_frame.buffer)

                # Push to input queue for core busy loop.
                self.input_queue.put_nowait((request_type, request))

    def process_output_socket(self, output_path: str):
        """Output socket IO thread.

        Takes ``EngineCoreOutputs`` from ``output_queue``, encodes each into
        a shared buffer and sends it on the PUSH socket. Loops forever.
        """

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Reuse send buffer.
        buffer = bytearray()

        with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
            while True:
                outputs = self.output_queue.get()
                encoder.encode_into(outputs, buffer)
                socket.send_multipart((buffer, ), copy=False)