engine.py 19.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
import pickle
import signal
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union

import cloudpickle
10
import vllm.envs as envs
lizhigong's avatar
lizhigong committed
11
from vllm.zero_overhead.llm_engine import ZeroOverheadEngine
12
13
import zmq

14
from vllm import AsyncEngineArgs, SamplingParams
15
from vllm.config import VllmConfig
Joe Runde's avatar
Joe Runde committed
16
from vllm.engine.llm_engine import LLMEngine
17
18
19
20
21
22
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                         IPC_HEALTH_EXT, IPC_INPUT_EXT,
                                         IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
23
                                         RPCAdapterLoadedResponse, RPCError,
24
25
                                         RPCIsSleepingRequest,
                                         RPCIsSleepingResponse,
26
                                         RPCLoadAdapterRequest,
27
                                         RPCProcessRequest,
28
                                         RPCResetMultiModalCacheRequest,
29
                                         RPCResetPrefixCacheRequest,
30
31
32
                                         RPCSleepRequest, RPCStartupRequest,
                                         RPCStartupResponse,
                                         RPCUProfileRequest, RPCWakeUpRequest)
33
34
35
# yapf: enable
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
36
37
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
38
from vllm.usage.usage_lib import UsageContext
39
from vllm.utils import deprecate_kwargs
40
from vllm.worker.model_runner_base import InputProcessingError
41
import time
42
43
44
45
46
47
48
49

logger = init_logger(__name__)

# Idle engine loop: how long (in ms) one poll of the input socket blocks
# before a health check is sent and stats are logged.
POLLING_TIMEOUT_MS = 10_000
# Heartbeat payload, pickled once at import time and reused for every
# HEALTHY message.
HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR),)


class MQLLMEngine:
    """A multiprocessing wrapper for
    [`LLMEngine`][vllm.engine.llm_engine.LLMEngine].

    This class is used to wrap the
    [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] class to enable use
    in a concurrent manner. It runs a background loop and uses zeromq to
    receive new requests and stream outputs incrementally via ipc.

    When ``envs.VLLM_ZERO_OVERHEAD`` is set, a ``ZeroOverheadEngine`` is
    wrapped instead of an ``LLMEngine`` (same constructor arguments).

    The [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] generate or encode
    process is kicked off when a new RPCProcessRequest is received by the
    input_socket.

    ``run_engine_loop`` checks the input_socket for new requests,
    adds them to the LLMEngine if there are any, calls the internal
    [`LLMEngine.step()`][vllm.engine.llm_engine.LLMEngine.step], and sends
    the RequestOutputs back over the output_socket.

    If use_async_sockets is set, the logic associated with reading new
    requests from the socket and sending data to the socket is passed
    as a callback to the llm_engine, which calls the logic asynchronously
    such that the IPC can be overlapped with the GPU.

    Args:
        ipc_path: Base path for zeromq interprocess messaging
        use_async_sockets: Whether to make send/recv async with GPU
        log_requests: Whether to log the requests.
        *args: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
        **kwargs: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
    """

    def __init__(self,
                 ipc_path: str,
                 use_async_sockets: bool,
                 *args,
                 log_requests: bool = True,
                 **kwargs) -> None:
        # For MQLLMEngine, we can use cached outputs, since each new request
        # output is immediately pickled and sent over the socket, which frees
        # the python object to be reused again.
        kwargs['use_cached_outputs'] = True

        if envs.VLLM_ZERO_OVERHEAD:
            self.engine = ZeroOverheadEngine(*args, **kwargs)
        else:
            self.engine = LLMEngine(*args, **kwargs)
        self.log_requests = log_requests

        self.use_async_sockets = use_async_sockets
        if self.use_async_sockets:
            # The engine invokes this callback itself, so sending outputs
            # and reading new input can overlap with GPU work.
            self.engine.process_request_outputs_callback = \
                self._async_socket_engine_callback

        self.ctx = zmq.Context()  # type: ignore[attr-defined]

        # Receive input from the client.
        self.input_socket = self.ctx.socket(zmq.constants.PULL)
        self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}")

        # Send output stream back to client.
        self.output_socket = self.ctx.socket(zmq.constants.PUSH)
        self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")

        # Send heartbeats back to client.
        self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
        self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")

        # IPC path for the data socket (bound only during startup, see
        # make_data_socket / run_startup_loop).
        self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

        # Error state: the first fatal exception seen, or None.
        self._errored_with: Optional[BaseException] = None

    @property
    def dead_error(self) -> BaseException:
        """ENGINE_DEAD_ERROR, wrapping the recorded error if there is one."""
        if self._errored_with is not None:
            return ENGINE_DEAD_ERROR(self._errored_with)
        else:
            return ENGINE_DEAD_ERROR()

    @classmethod
    @deprecate_kwargs(
        "disable_log_requests",
        additional_message=("This argument will have no effect. "
                            "Use `enable_log_requests` instead."),
    )
    def from_vllm_config(
            cls,
            vllm_config: VllmConfig,
            usage_context: UsageContext,
            enable_log_requests: bool,
            disable_log_stats: bool,
            ipc_path: str,
            disable_log_requests: bool = True,  # Deprecated, will be removed
    ) -> "MQLLMEngine":
        """Create an MQLLMEngine from an already-built ``VllmConfig``.

        General plugins are loaded first. Socket handling is made async
        with the GPU whenever the model config enables async output
        processing. ``disable_log_requests`` is accepted only for backward
        compatibility and is otherwise ignored.
        """
        # Setup plugins for each process
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        use_async_sockets = vllm_config.model_config.use_async_output_proc

        return cls(
            vllm_config=vllm_config,
            executor_class=LLMEngine._get_executor_cls(vllm_config),
            ipc_path=ipc_path,
            usage_context=usage_context,
            use_async_sockets=use_async_sockets,
            log_requests=enable_log_requests,
            log_stats=(not disable_log_stats),
        )

    @staticmethod
    def from_engine_args(engine_args: AsyncEngineArgs,
                         usage_context: UsageContext, ipc_path: str):
        """Creates an MQLLMEngine from the engine arguments."""

        vllm_config = engine_args.create_engine_config(usage_context)
        return MQLLMEngine.from_vllm_config(
            ipc_path=ipc_path,
            vllm_config=vllm_config,
            usage_context=usage_context,
            enable_log_requests=engine_args.enable_log_requests,
            disable_log_stats=engine_args.disable_log_stats,
        )

    def start(self):
        """Run the startup loop, then the engine loop, until shutdown.

        Ordinary exceptions are logged; KeyboardInterrupt (also raised for
        SIGTERM by the module's signal handler) ends the loop quietly.
        Sockets and the engine are always released via ``cleanup``.
        """
        try:
            try:
                logger.debug("Starting Startup Loop.")
                self.run_startup_loop()
                logger.debug("Starting Engine Loop.")
                self.run_engine_loop()
            except Exception as e:
                logger.exception(repr(e))
        except KeyboardInterrupt:
            logger.debug("Shutting down MQLLMEngine.")
        finally:
            logger.debug("MQLLMEngine is shut down.")
            self.cleanup()

    def cleanup(self):
        """Cleanup zeromq state on shutdown."""
        # Closes all sockets and destroys context.
        self.ctx.destroy(linger=0)
        del self.engine

    @contextmanager
    def make_data_socket(
            self) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
        """Yield a ROUTER socket bound to the data IPC path.

        The socket is closed (without lingering) when the context exits.
        """
        socket = self.ctx.socket(zmq.constants.ROUTER)
        try:
            socket.bind(self.data_ipc_path)
            yield socket
        finally:
            socket.close(linger=0)

    def run_startup_loop(self) -> None:
        """Startup loop for sending data from Engine -> Client.

        Waits for one startup query on the data socket and answers it with
        either an ``RPCStartupResponse`` or the exception raised while
        handling the query.
        """

        with self.make_data_socket() as socket:
            # Receive outside the try: without the client's identity frame
            # there is nobody to reply to, so a receive failure must
            # propagate as itself rather than as an UnboundLocalError.
            identity, message = socket.recv_multipart(copy=False)
            response: Union[RPCStartupResponse, BaseException]
            try:
                request: RPCStartupRequest = pickle.loads(message.buffer)

                # Handle the query from the Client.
                if request == RPCStartupRequest.IS_SERVER_READY:
                    tracing_enabled = self.engine.is_tracing_enabled()
                    response = RPCStartupResponse(
                        tracing_enabled=tracing_enabled)
                else:
                    # Always reply, so `response` is never left unbound.
                    response = ValueError(
                        f"Unknown RPCStartupRequest: {request!r}")

            except Exception as e:
                response = e

            socket.send_multipart((identity, pickle.dumps(response)),
                                  copy=False)

    def run_engine_loop(self):
        """Core busy loop of the LLMEngine.

        While the engine has no unfinished requests, blocks on the input
        socket, sending a heartbeat and logging stats every
        ``POLLING_TIMEOUT_MS``. It then drains client input, steps the
        engine, and (unless async sockets are used) sends the outputs.

        If ``envs.VLLM_ENABLE_TBO`` is set and ``envs.VLLM_TBO_REQ_DELAY_MS``
        is positive, the first step after an idle period is deferred. While
        fewer than two requests are unfinished, it waits up to that many ms
        (in 10 ms sleeps) so that requests arriving close together can be
        merged into one batch.
        """
        # Read the batching knobs once, outside the hot loop. `envs`
        # attributes are presumably resolved from the environment on access.
        # The `and` keeps the original short-circuit, so VLLM_ENABLE_TBO is
        # only read when a delay is configured.
        tbo_delay_ms = envs.VLLM_TBO_REQ_DELAY_MS
        tbo_delay_enabled = tbo_delay_ms > 0 and envs.VLLM_ENABLE_TBO

        last_no_req_time_refreshed = True
        last_no_req_time = time.perf_counter()

        while True:
            if not self.engine.has_unfinished_requests():
                # Poll until there is work to do.
                while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
                    # When there's no work, check on engine health and send
                    # health status back to client
                    self._health_check()
                    self.engine.do_log_stats()
                    logger.debug("Waiting for new requests in engine loop.")
                # Input just arrived after an idle period: start the
                # batching window from now.
                last_no_req_time = time.perf_counter()
                last_no_req_time_refreshed = True

            # Handle any input from the client.
            self.handle_new_input()

            if (tbo_delay_enabled and last_no_req_time_refreshed
                    and self.engine.get_num_unfinished_requests() < 2):
                time_diff_ms = int(
                    (time.perf_counter() - last_no_req_time) * 1000)
                if time_diff_ms < tbo_delay_ms:
                    # Sleep and wait for more requests to merge into one
                    # batch.
                    time.sleep(0.01)
                    continue

            # The batching window only applies to the first step after idle.
            last_no_req_time_refreshed = False

            # Engine step.
            request_outputs = self.engine_step()

            # Send request outputs (if async, done in engine_step callback).
            if not self.use_async_sockets:
                self._send_outputs(request_outputs)

    def engine_step(self) -> List[RequestOutput]:
        """Engine step wrapper with error handling.

        An ``InputProcessingError`` fails only its own request (reported to
        the client) and yields no outputs. Any other exception marks the
        engine as errored, is reported to the client, and is re-raised.
        ``SystemExit`` passes through untouched.
        """
        try:
            return self.engine.step()
        except SystemExit:
            raise
        except InputProcessingError as e:
            # Special case where we handle an error preparing the inputs for
            # a single request in the batch
            rpc_err = RPCError(request_id=e.request_id,
                               is_engine_errored=False,
                               exception=e.__cause__)
            self._send_outputs(rpc_err)
            return []
        except BaseException as e:
            self._set_errored(e)
            rpc_err = RPCError(request_id=None,
                               is_engine_errored=True,
                               exception=e)
            self._send_outputs(rpc_err)
            raise e

    def handle_new_input(self):
        """Handle new input from the socket.

        Drains every message currently queued on the input socket without
        blocking, dispatching each by RPC request type. Any failure marks
        the engine as errored, notifies the heartbeat socket, and re-raises.
        """
        try:
            while self.input_socket.poll(timeout=0) != 0:
                frames = self.input_socket.recv_multipart(copy=False)
                # NOTE(review): frames are unpickled (and logits processors
                # cloudpickled); this assumes only the trusted client can
                # reach the IPC input path.
                request = pickle.loads(frames[0].buffer)

                if isinstance(request, RPCProcessRequest):
                    if len(frames) > 1:
                        # Use cloudpickle for logits processors
                        assert isinstance(request.params, SamplingParams)
                        lprocs = cloudpickle.loads(frames[1].buffer)
                        request.params.logits_processors = lprocs
                    self._handle_process_request(request)
                elif isinstance(request, RPCAbortRequest):
                    self._handle_abort_request(request)
                elif isinstance(request, RPCUProfileRequest):
                    if request == RPCUProfileRequest.START_PROFILE:
                        self.start_profile()
                    else:
                        self.stop_profile()
                elif isinstance(request, RPCLoadAdapterRequest):
                    self._handle_load_adapter_request(request)
                elif isinstance(request, RPCResetMultiModalCacheRequest):
                    self.reset_mm_cache()
                elif isinstance(request, RPCResetPrefixCacheRequest):
                    self.reset_prefix_cache()
                elif isinstance(request, RPCSleepRequest):
                    self.sleep(request.value)
                elif isinstance(request, RPCWakeUpRequest):
                    self.wake_up(request.tags)
                elif isinstance(request, RPCIsSleepingRequest):
                    self._handle_is_sleeping_request(request)
                else:
                    raise ValueError("Unknown RPCRequest Type: "
                                     f"{type(request)}")

        except Exception as e:
            self._set_errored(e)
            self._send_unhealthy(e)
            raise e from None

    def _handle_process_request(self, request: RPCProcessRequest):
        """Handle RPCProcessRequest by adding it to the LLMEngine.

        If the engine has already errored, the client gets an engine-dead
        error and the request is not added. If adding fails, the client gets
        an ``RPCError`` for this request and it is aborted in the engine.
        """
        request_id = request.request_id

        if self._errored_with is not None:
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=True,
                               exception=ENGINE_DEAD_ERROR(self._errored_with))
            self._send_outputs(rpc_err)
            # The engine is dead: do not hand it new work (and do not send
            # the client a second error for the same request).
            return

        try:
            self.engine.add_request(request_id=request_id,
                                    prompt=request.prompt,
                                    params=request.params,
                                    lora_request=request.lora_request,
                                    trace_headers=request.trace_headers,
                                    priority=request.priority)

            if self.log_requests:
                logger.info("Added request %s.", request.request_id)

        except Exception as e:
            # We do not set self._errored = True here, since the error
            # is due to an issue adding this request to the engine,
            # rather than an issue with the engine itself.
            logger.debug("Failed to add request %s to engine. %s",
                         request.request_id, e)
            is_errored = self._errored_with is not None
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=is_errored,
                               exception=e)
            self._send_outputs(rpc_err)

            # Remove request from the engine.
            self.engine.abort_request(request_id)

    def _handle_abort_request(self, request: RPCAbortRequest):
        """Abort the request in the engine, optionally logging it."""
        self.engine.abort_request(request.request_id)
        if self.log_requests:
            logger.info("Aborted request %s.", request.request_id)

    def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
        """Load a LoRA adapter and report the outcome to the client.

        A load failure is reported as a non-fatal ``RPCError`` for this
        request. Only ``Exception`` is caught, so that a KeyboardInterrupt
        (raised for SIGTERM by the module's signal handler) still reaches
        ``start()`` and shuts the engine down.
        """
        try:
            lora_loaded = self.engine.add_lora(request.lora_request)
        except Exception as e:
            # Send back an error if the adapter fails to load
            rpc_err = RPCError(request_id=request.request_id,
                               is_engine_errored=False,
                               exception=e)
            self._send_outputs(rpc_err)
            return

        # Otherwise, send back the successful load message
        self._send_outputs(
            RPCAdapterLoadedResponse(request_id=request.request_id,
                                     lora_loaded=lora_loaded))

    def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest):
        """Reply to the client with the engine's current sleep state."""
        is_sleeping = self.is_sleeping()
        self._send_outputs(
            RPCIsSleepingResponse(request_id=request.request_id,
                                  is_sleeping=is_sleeping))

    def _health_check(self):
        """Send one heartbeat reflecting the engine's health.

        An engine that has already errored always reports UNHEALTHY.
        Otherwise ``check_health()`` decides; a failure there marks the
        engine as errored.
        """
        # Send unhealthy if engine has already errored
        if self._errored_with is not None:
            self._send_unhealthy(self._errored_with)
            # Don't follow the UNHEALTHY message with a contradictory
            # HEALTHY one.
            return

        try:
            self.engine.check_health()
            self._send_healthy()
        except Exception as e:
            self._set_errored(e)
            self._send_unhealthy(e)

    def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
        """Send outputs back to the engine client. These can be:
        - Exceptions (wrapped in RPCError)
        - A list of generation outputs
        - A response from loading a lora adapter
        - A response to an is-sleeping query

        Falsy outputs (e.g. an empty list) are not sent.
        """
        if not outputs:
            return

        if isinstance(outputs, RPCError):
            # RayTaskError might not be picklable here. We need to unpack the
            # underlying exception as the real exception in the output.
            # Only errors need this check, so the ray import (which fails,
            # and is re-attempted, on every call when ray is absent) stays
            # off the per-step output path.
            try:
                from ray.exceptions import RayTaskError
            except ImportError:
                pass
            else:
                if isinstance(outputs.exception, RayTaskError):
                    outputs.exception = outputs.exception.cause

        output_bytes = pickle.dumps(outputs)
        self.output_socket.send_multipart((output_bytes, ), copy=False)

    def _send_healthy(self):
        """Send HEALTHY message to RPCClient."""
        if not self.heartbeat_socket.closed:
            self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)

    def _send_unhealthy(self, error: BaseException):
        """Send UNHEALTHY message to RPCClient."""
        if not self.heartbeat_socket.closed:
            error_bytes = pickle.dumps(error)
            self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)

    def _async_socket_engine_callback(self,
                                      request_outputs: REQUEST_OUTPUTS_T):
        """Callback used by engine to make socket handling async with GPU."""
        self._send_outputs(request_outputs)
        self.handle_new_input()

    def _set_errored(self, e: BaseException):
        """Record ``e`` as the engine error, unless one is already recorded.

        Only the first error is kept; nothing is logged here.
        """
        if self._errored_with is None:
            self._errored_with = e

    def start_profile(self) -> None:
        self.engine.start_profile()

    def stop_profile(self) -> None:
        self.engine.stop_profile()

    def reset_mm_cache(self) -> bool:
        return self.engine.reset_mm_cache()

    def reset_prefix_cache(self) -> bool:
        return self.engine.reset_prefix_cache()

    def sleep(self, level: int = 1) -> None:
        self.engine.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        self.engine.wake_up(tags)

    def is_sleeping(self) -> bool:
        return self.engine.is_sleeping()


def signal_handler(*_) -> None:
    """Convert a delivered signal into a KeyboardInterrupt.

    Installed for SIGTERM by ``run_mp_engine``. The interrupt sends
    ``MQLLMEngine.start`` down its KeyboardInterrupt shutdown path.
    Any handler arguments (signal number, frame) are ignored.
    """
    reason = "MQLLMEngine terminated"
    raise KeyboardInterrupt(reason)


def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
                  ipc_path: str, disable_log_stats: bool,
                  enable_log_requests: bool, engine_alive):
    """Build an ``MQLLMEngine`` from ``vllm_config`` and run it to completion.

    SIGTERM is mapped to KeyboardInterrupt once the engine exists, so that
    ``start()`` can shut down cleanly.

    Args:
        vllm_config: Configuration used to build the engine.
        usage_context: Forwarded to ``MQLLMEngine.from_vllm_config``.
        ipc_path: Base path for the engine's zeromq IPC sockets.
        disable_log_stats: Forwarded to ``MQLLMEngine.from_vllm_config``.
        enable_log_requests: Forwarded to ``MQLLMEngine.from_vllm_config``.
        engine_alive: Object whose ``value`` is set to ``False`` when any
            exception escapes (presumably a shared ``multiprocessing.Value``
            read by the parent; TODO confirm).

    Raises:
        Whatever escapes construction or ``start()``, after it is logged,
        re-raised without its context chain.
    """
    try:
        # Ensure we can serialize transformer config before spawning
        maybe_register_config_serialize_by_value()

        mq_engine = MQLLMEngine.from_vllm_config(
            ipc_path=ipc_path,
            vllm_config=vllm_config,
            usage_context=usage_context,
            enable_log_requests=enable_log_requests,
            disable_log_stats=disable_log_stats,
        )

        # From here on, SIGTERM surfaces as KeyboardInterrupt inside start().
        signal.signal(signal.SIGTERM, signal_handler)

        mq_engine.start()
    except BaseException as err:
        logger.exception(err)
        engine_alive.value = False
        raise err from None