"vscode:/vscode.git/clone" did not exist on "41ca62cf03b31deb68dbc14e4a92a1d4579de08b"
core.py 17.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import queue
4
import signal
5
6
import threading
import time
7
from concurrent.futures import Future
8
from inspect import isclass, signature
9
from multiprocessing.connection import Connection
10
from typing import Any, Optional
11

12
import msgspec
13
import psutil
14
15
16
import zmq
import zmq.asyncio

17
from vllm.config import VllmConfig
18
from vllm.logger import init_logger
19
from vllm.lora.request import LoRARequest
20
21
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
22
23
from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
                        zmq_socket_ctx)
24
25
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                         unify_kv_cache_configs)
26
27
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
28
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
29
                            EngineCoreRequestType, UtilityOutput)
30
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
31
from vllm.v1.executor.abstract import Executor
32
from vllm.v1.outputs import ModelRunnerOutput
33
from vllm.v1.request import Request, RequestStatus
34
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
35
from vllm.v1.structured_output import StructuredOutputManager
36
37
38
39
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

# NOTE(review): not referenced by any code visible in this module; the name
# suggests a polling timeout in seconds. Confirm external users before
# removing it.
POLLING_TIMEOUT_S = 2.5


class EngineCore:
    """Inner loop of vLLM's Engine."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        """Build the model executor, KV caches, scheduler and helpers.

        Args:
            vllm_config: Engine configuration. Note that
                ``cache_config.num_gpu_blocks`` / ``num_cpu_blocks`` are
                overwritten in place with the values computed during KV cache
                initialization.
            executor_class: Executor type, instantiated with ``vllm_config``.
            log_stats: Forwarded to the scheduler to enable stats logging.
        """
        assert vllm_config.model_config.runner_type != "pooling"

        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)

        # Setup KV Caches and update CacheConfig after profiling.
        num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
            vllm_config)
        vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
        vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

        self.structured_output_manager = StructuredOutputManager(vllm_config)

        # Setup scheduler. The scheduler class may be given either as a class
        # object or as a fully-qualified name to resolve.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls)

        self.scheduler = Scheduler(
            scheduler_config=vllm_config.scheduler_config,
            model_config=vllm_config.model_config,
            cache_config=vllm_config.cache_config,
            lora_config=vllm_config.lora_config,
            speculative_config=vllm_config.speculative_config,
            log_stats=self.log_stats,
            structured_output_manager=self.structured_output_manager,
        )

        # Setup MM Input Mapper.
        self.mm_input_cache_server = MMInputCacheServer(
            vllm_config.model_config)

        # Setup batch queue for pipeline parallelism.
        # Batch queue for scheduled batches. This enables us to asynchronously
        # schedule and execute batches, and is required by pipeline parallelism
        # to eliminate pipeline bubbles.
        self.batch_queue_size = self.model_executor.max_concurrent_batches
        self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
                                                     SchedulerOutput]]] = None
        if self.batch_queue_size > 1:
            logger.info("Batch queue is enabled with size %d",
                        self.batch_queue_size)
            self.batch_queue = queue.Queue(self.batch_queue_size)

    def _initialize_kv_caches(self,
                              vllm_config: VllmConfig) -> tuple[int, int]:
        """Profile memory, size the KV cache and initialize it on all workers.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``; ``num_cpu_blocks`` is
            always 0 here.
        """
        # Monotonic clock: immune to wall-clock adjustments while timing.
        start = time.perf_counter()

        # Get all kv cache needed by the model
        kv_cache_specs = self.model_executor.get_kv_cache_specs()

        # Profiles the peak memory usage of the model to determine how much
        # memory can be allocated for kv cache.
        available_gpu_memory = self.model_executor.determine_available_memory()

        # One spec and one memory figure per worker.
        assert len(kv_cache_specs) == len(available_gpu_memory)
        # Get the kv cache tensor size
        kv_cache_configs = [
            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
                                available_gpu_memory_one_worker)
            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
            zip(kv_cache_specs, available_gpu_memory)
        ]

        # Since we use a shared centralized controller, we need the
        # `kv_cache_config` to be consistent across all workers to make sure
        # all the memory operators can be applied to all workers.
        unify_kv_cache_configs(kv_cache_configs)

        # All workers have the same kv_cache_config except layer names, so use
        # an arbitrary one to get the number of blocks.
        assert all(cfg.num_blocks == kv_cache_configs[0].num_blocks
                   for cfg in kv_cache_configs)
        num_gpu_blocks = kv_cache_configs[0].num_blocks
        num_cpu_blocks = 0

        # Initialize kv cache and warmup the execution
        self.model_executor.initialize_from_config(kv_cache_configs)

        elapsed = time.perf_counter() - start
        logger.info(("init engine (profile, create kv cache, "
                     "warmup model) took %.2f seconds"), elapsed)
        return num_gpu_blocks, num_cpu_blocks

    def add_request(self, request: EngineCoreRequest):
        """Add request to the scheduler.

        Multimodal inputs are first resolved through the MM input cache when
        hashes are present. Grammar compilation is kicked off for requests
        that use structured output.
        """
        if request.mm_hashes is not None:
            # Here, if hash exists for a multimodal input, then it will be
            # fetched from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client cache, so
            # anything that has a hash must have a HIT cache entry here
            # as well.
            assert request.mm_inputs is not None
            request.mm_inputs = self.mm_input_cache_server.get_and_update(
                request.mm_inputs, request.mm_hashes)

        req = Request.from_engine_core_request(request)
        if req.use_structured_output:
            # Start grammar compilation asynchronously
            self.structured_output_manager.grammar_init(req)

        self.scheduler.add_request(req)

    def abort_requests(self, request_ids: list[str]):
        """Abort requests from the scheduler."""

        # TODO: The scheduler doesn't really need to know the
        # specific finish reason, TBD whether we propagate that
        # (i.e. client-aborted vs stop criteria met).
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

    def step(self) -> EngineCoreOutputs:
        """Schedule, execute, and make output.

        When the scheduler holds no requests, returns an output with no
        request outputs but with current scheduler stats.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return EngineCoreOutputs(
                outputs=[],
                scheduler_stats=self.scheduler.make_stats(),
            )
        scheduler_output = self.scheduler.schedule()
        output = self.model_executor.execute_model(scheduler_output)
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, output)  # type: ignore

        return engine_core_outputs

    def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
        """Schedule and execute batches with the batch queue.
        Note that if nothing to output in this step, None is returned.

        The execution flow is as follows:
        1. Try to schedule a new batch if there are unscheduled requests
        and the job queue is not full. If a new batch is scheduled, directly
        return an empty engine core output. In other words, we won't check
        and return model outputs before the batch queue is full.
        2. If there is no new scheduled batch, meaning that the batch queue
        is full or no other requests can be scheduled, we block until the first
        batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
        assert self.batch_queue is not None

        engine_core_outputs = None
        scheduler_output = None
        # If there are unscheduled requests and the job queue
        # is not full, schedule a new batch. Note that this is not blocking.
        if (self.scheduler.get_num_unscheduled_requests() > 0
                and not self.batch_queue.full()):
            scheduler_output = self.scheduler.schedule()
            # A schedule with no tokens is not submitted for execution.
            if scheduler_output.total_num_scheduled_tokens > 0:
                future = self.model_executor.execute_model(scheduler_output)
                self.batch_queue.put_nowait(
                    (future, scheduler_output))  # type: ignore

        scheduled_batch = (scheduler_output is not None
                           and scheduler_output.total_num_scheduled_tokens > 0)

        # If no more requests can be scheduled and the job queue is not empty,
        # block until the first batch in the job queue is finished.
        if not scheduled_batch and not self.batch_queue.empty():
            future, scheduler_output = self.batch_queue.get_nowait()
            # Blocking until the first result is available.
            model_output = future.result()
            self.batch_queue.task_done()
            engine_core_outputs = self.scheduler.update_from_output(
                scheduler_output, model_output)

        return engine_core_outputs

    def shutdown(self):
        """Shut down the model executor."""
        self.model_executor.shutdown()

    def profile(self, is_start: bool = True):
        """Forward a profiling start (``is_start=True``) or stop request to
        the model executor."""
        self.model_executor.profile(is_start)

    def reset_prefix_cache(self):
        """Ask the scheduler to reset its prefix cache."""
        self.scheduler.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Put the model executor to sleep at the given level."""
        self.model_executor.sleep(level)

    def wake_up(self):
        """Wake the model executor up from sleep."""
        self.model_executor.wake_up()

    def is_sleeping(self) -> bool:
        """Return the executor's ``is_sleeping`` state."""
        return self.model_executor.is_sleeping

    def execute_dummy_batch(self):
        """Run ``execute_dummy_batch`` on all workers via collective RPC."""
        self.model_executor.collective_rpc("execute_dummy_batch")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register a LoRA adapter with the executor."""
        return self.model_executor.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter with id ``lora_id`` from the executor."""
        return self.model_executor.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """Return the ids of the LoRA adapters known to the executor."""
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA adapter with id ``lora_id`` in the executor."""
        return self.model_executor.pin_lora(lora_id)


class EngineCoreProc(EngineCore):
    """ZMQ-wrapper for running EngineCore in background process."""

    def __init__(
        self,
        input_path: str,
        output_path: str,
        ready_pipe: Connection,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        """Initialize the engine core, start the IO threads, signal ready.

        Args:
            input_path: Address of the ZMQ PULL socket requests arrive on.
            output_path: Address of the ZMQ PUSH socket outputs are sent on.
            ready_pipe: Pipe on which ``{"status": "READY"}`` is sent once
                initialization has finished.
            vllm_config, executor_class, log_stats: Forwarded to
                ``EngineCore.__init__``.
        """
        super().__init__(vllm_config, executor_class, log_stats)

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        self.input_queue: queue.Queue[tuple[EngineCoreRequestType,
                                            Any]] = queue.Queue()
        self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
        threading.Thread(target=self.process_input_socket,
                         args=(input_path, ),
                         daemon=True).start()
        threading.Thread(target=self.process_output_socket,
                         args=(output_path, ),
                         daemon=True).start()

        # Send Readiness signal to EngineClient.
        ready_pipe.send({"status": "READY"})

    @staticmethod
    def run_engine_core(*args, **kwargs):
        """Launch EngineCore busy loop in background process.

        All arguments are forwarded to ``EngineCoreProc``. SIGTERM/SIGINT end
        the loop via ``SystemExit``. Any other exception is logged and
        reported to the parent process with SIGUSR1. The engine is always
        shut down on exit.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        # Ensure we can serialize transformer config after spawning
        maybe_register_config_serialize_by_value()

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the engine_core
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        parent_process = psutil.Process().parent()
        engine_core = None
        try:
            engine_core = EngineCoreProc(*args, **kwargs)
            engine_core.run_busy_loop()

        except SystemExit:
            logger.debug("EngineCore interrupted.")

        except Exception:
            traceback = get_exception_traceback()
            logger.error("EngineCore hit an exception: %s", traceback)
            parent_process.send_signal(signal.SIGUSR1)

        finally:
            if engine_core is not None:
                engine_core.shutdown()

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

        # Pipeline-parallel setups (batch queue enabled) use the
        # asynchronous stepping variant.
        step_fn = (self.step
                   if self.batch_queue is None else self.step_with_batch_queue)

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            while not self.scheduler.has_requests():
                logger.debug("EngineCore busy loop waiting.")
                req = self.input_queue.get()
                self._handle_client_request(*req)

            # 2) Handle any new client requests.
            while not self.input_queue.empty():
                req = self.input_queue.get_nowait()
                self._handle_client_request(*req)

            # 3) Step the engine core.
            outputs = step_fn()

            # 4) Put EngineCoreOutputs into the output queue.
            if outputs is not None:
                self.output_queue.put_nowait(outputs)

    def _handle_client_request(self, request_type: EngineCoreRequestType,
                               request: Any) -> None:
        """Dispatch request from client.

        UTILITY requests invoke the named method on this object. The result,
        or a failure message, is sent back as a ``UtilityOutput``.
        """

        if request_type == EngineCoreRequestType.ADD:
            self.add_request(request)
        elif request_type == EngineCoreRequestType.ABORT:
            self.abort_requests(request)
        elif request_type == EngineCoreRequestType.UTILITY:
            call_id, method_name, args = request
            output = UtilityOutput(call_id)
            try:
                method = getattr(self, method_name)
                output.result = method(
                    *self._convert_msgspec_args(method, args))
            # Exception, not BaseException: the SIGTERM/SIGINT handler in
            # run_engine_core raises SystemExit only once, so swallowing it
            # here would leave the engine unable to terminate.
            except Exception as e:
                logger.exception("Invocation of %s method failed", method_name)
                output.failure_message = (f"Call to {method_name} method"
                                          f" failed: {str(e)}")
            self.output_queue.put_nowait(
                EngineCoreOutputs(utility_output=output))

    @staticmethod
    def _convert_msgspec_args(method, args):
        """If a provided arg type doesn't match corresponding target method
        arg type, try converting to msgspec object.

        Only parameters annotated with a ``msgspec.Struct`` subclass are
        converted; all other arguments are passed through unchanged.
        """
        if not args:
            return args
        arg_types = signature(method).parameters.values()
        assert len(args) <= len(arg_types)
        return tuple(
            msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
            and issubclass(p.annotation, msgspec.Struct)
            and not isinstance(v, p.annotation) else v
            for v, p in zip(args, arg_types))

    def process_input_socket(self, input_path: str):
        """Input socket IO thread.

        Receives ``(type, data)`` multipart messages, decodes them and pushes
        ``(request_type, request)`` onto ``self.input_queue``.
        """

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()

        with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
            while True:
                # (RequestType, RequestData)
                type_frame, data_frame = socket.recv_multipart(copy=False)
                request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                # Deserialize the request data.
                decoder = add_request_decoder if (
                    request_type
                    == EngineCoreRequestType.ADD) else generic_decoder
                request = decoder.decode(data_frame.buffer)

                # Push to input queue for core busy loop.
                self.input_queue.put_nowait((request_type, request))

    def process_output_socket(self, output_path: str):
        """Output socket IO thread.

        Serializes items from ``self.output_queue`` and pushes them over a
        ZMQ PUSH socket. Send buffers are reused only after zmq reports
        that it has finished with them.
        """
        from collections import deque

        # Msgpack serialization encoding.
        encoder = MsgpackEncoder()
        # Idle send buffers that may be safely rewritten.
        reuse_buffers: list[bytearray] = []
        # Buffers referenced by still-in-flight zero-copy sends, newest on
        # the left. With copy=False, zmq keeps using the bytearray after
        # send_multipart returns. Re-encoding into it early would corrupt
        # the message, and resizing it would raise BufferError.
        pending: deque[tuple[zmq.MessageTracker, bytearray]] = deque()

        with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
            while True:
                outputs = self.output_queue.get()

                # Reclaim buffers whose sends have completed (oldest first).
                while pending and pending[-1][0].done:
                    reuse_buffers.append(pending.pop()[1])

                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                encoder.encode_into(outputs, buffer)
                tracker = socket.send_multipart((buffer, ),
                                                copy=False,
                                                track=True)
                if tracker is not None and not tracker.done:
                    pending.appendleft((tracker, buffer))
                elif len(reuse_buffers) < 2:
                    # Keep at most 2 idle buffers around for reuse.
                    reuse_buffers.append(buffer)