engine.py 18.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
import pickle
import signal
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union

import cloudpickle
import zmq

12
from vllm import AsyncEngineArgs, SamplingParams
13
from vllm.config import VllmConfig
Joe Runde's avatar
Joe Runde committed
14
from vllm.engine.llm_engine import LLMEngine
15
16
17
18
19
20
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                         IPC_HEALTH_EXT, IPC_INPUT_EXT,
                                         IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
21
                                         RPCAdapterLoadedResponse, RPCError,
22
23
                                         RPCIsSleepingRequest,
                                         RPCIsSleepingResponse,
24
                                         RPCLoadAdapterRequest,
25
                                         RPCProcessRequest,
26
                                         RPCResetMultiModalCacheRequest,
27
                                         RPCResetPrefixCacheRequest,
28
29
30
                                         RPCSleepRequest, RPCStartupRequest,
                                         RPCStartupResponse,
                                         RPCUProfileRequest, RPCWakeUpRequest)
31
32
33
# yapf: enable
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
34
35
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
36
from vllm.usage.usage_lib import UsageContext
37
from vllm.utils import deprecate_kwargs
38
from vllm.worker.model_runner_base import InputProcessingError
39
40
41
42
43
44
45
46

logger = init_logger(__name__)

# Timeout (ms) for each poll of the input socket while the engine is idle;
# every timeout triggers a health check and a stats log in run_engine_loop.
POLLING_TIMEOUT_MS = 10_000
# Pre-pickled heartbeat frame(s) pushed on the heartbeat socket whenever a
# health check succeeds.
HEALTHY_RESPONSE = tuple([pickle.dumps(VLLM_RPC_SUCCESS_STR)])


class MQLLMEngine:
    """A multiprocessing wrapper for
    [`LLMEngine`][vllm.engine.llm_engine.LLMEngine].

    This class is used to wrap the
    [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] class to enable use
    in a concurrent manner. It runs a background loop and uses zeromq to
    receive new requests and stream outputs incrementally via ipc.

    The [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] generate or encode
    process is kicked off when a new RPCProcessRequest is received by the
    input_socket.

    The run_engine_loop method checks the input_socket for new requests,
    adds them to the LLMEngine if there are any, calls the internal
    [`LLMEngine.step()`][vllm.engine.llm_engine.LLMEngine.step], and sends
    the RequestOutputs back over the output_socket.

    If use_async_sockets is set, the logic associated with reading new
    requests from the socket and sending data to the socket is passed
    as a callback to the llm_engine, which calls the logic asynchronously
    such that the IPC can be overlapped with the GPU.

    Args:
        ipc_path: Base path for zeromq interprocess messaging
        use_async_sockets: Whether to make send/recv async with GPU
        log_requests: Whether to log the requests.
        *args: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
        **kwargs: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
    """

    def __init__(self,
                 ipc_path: str,
                 use_async_sockets: bool,
                 *args,
                 log_requests: bool = True,
                 **kwargs) -> None:
        # For MQLLMEngine, we can use cached outputs, since each new request
        # output is immediately pickled and sent over the socket, which frees
        # the python object to be reused again.
        kwargs['use_cached_outputs'] = True

        self.engine = LLMEngine(*args, **kwargs)
        self.log_requests = log_requests

        self.use_async_sockets = use_async_sockets
        if self.use_async_sockets:
            # Let the engine drive socket I/O through a callback so that it
            # can overlap with GPU work (see class docstring).
            self.engine.process_request_outputs_callback = \
                self._async_socket_engine_callback

        self.ctx = zmq.Context()  # type: ignore[attr-defined]

        # Receive input from the client.
        self.input_socket = self.ctx.socket(zmq.constants.PULL)
        self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}")

        # Send output stream back to client.
        self.output_socket = self.ctx.socket(zmq.constants.PUSH)
        self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")

        # Send heartbeats back to client.
        self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
        self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")

        # IPC path for the data socket.
        self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

        # Error state: the first exception that put the engine in a failed
        # state, or None while healthy.
        self._errored_with: Optional[BaseException] = None

    @property
    def dead_error(self) -> BaseException:
        """ENGINE_DEAD_ERROR wrapping the recorded error, if any."""
        if self._errored_with is not None:
            return ENGINE_DEAD_ERROR(self._errored_with)
        else:
            return ENGINE_DEAD_ERROR()

    @classmethod
    @deprecate_kwargs(
        "disable_log_requests",
        additional_message=("This argument will have no effect. "
                            "Use `enable_log_requests` instead."),
    )
    def from_vllm_config(
            cls,
            vllm_config: VllmConfig,
            usage_context: UsageContext,
            enable_log_requests: bool,
            disable_log_stats: bool,
            ipc_path: str,
            disable_log_requests: bool = True,  # Deprecated, will be removed
    ) -> "MQLLMEngine":
        """Create an MQLLMEngine from a resolved VllmConfig.

        ``disable_log_requests`` is accepted only for backward compatibility
        and is ignored; ``enable_log_requests`` controls request logging.
        """
        # Setup plugins for each process
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        use_async_sockets = vllm_config.model_config.use_async_output_proc

        return cls(
            vllm_config=vllm_config,
            executor_class=LLMEngine._get_executor_cls(vllm_config),
            ipc_path=ipc_path,
            usage_context=usage_context,
            use_async_sockets=use_async_sockets,
            log_requests=enable_log_requests,
            log_stats=(not disable_log_stats),
        )

    @staticmethod
    def from_engine_args(engine_args: AsyncEngineArgs,
                         usage_context: UsageContext, ipc_path: str):
        """Creates an MQLLMEngine from the engine arguments."""

        vllm_config = engine_args.create_engine_config(usage_context)
        return MQLLMEngine.from_vllm_config(
            ipc_path=ipc_path,
            vllm_config=vllm_config,
            usage_context=usage_context,
            enable_log_requests=engine_args.enable_log_requests,
            disable_log_stats=engine_args.disable_log_stats,
        )

    def start(self):
        """Run the startup loop, then the engine loop, until shutdown.

        Exceptions from either loop are logged; KeyboardInterrupt (also
        raised by the module's SIGTERM handler) ends the run. ZMQ state is
        always cleaned up on exit.
        """
        try:
            try:
                logger.debug("Starting Startup Loop.")
                self.run_startup_loop()
                logger.debug("Starting Engine Loop.")
                self.run_engine_loop()
            except Exception as e:
                logger.exception(repr(e))
        except KeyboardInterrupt:
            logger.debug("Shutting down MQLLMEngine.")
        finally:
            logger.debug("MQLLMEngine is shut down.")
            self.cleanup()

    def cleanup(self):
        """Cleanup zeromq state on shutdown."""
        # Closes all sockets and destroys context.
        self.ctx.destroy(linger=0)
        del self.engine

    @contextmanager
    def make_data_socket(
            self) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
        """Bind a ROUTER socket on the data IPC path; close it on exit."""
        socket = self.ctx.socket(zmq.constants.ROUTER)
        try:
            socket.bind(self.data_ipc_path)
            yield socket
        finally:
            socket.close(linger=0)

    def run_startup_loop(self) -> None:
        """Startup loop for sending data from Engine -> Client.

        Answers one startup query from the client over the data socket.
        Errors raised while decoding or handling the query are pickled and
        sent back to the client in place of an RPCStartupResponse.
        """

        with self.make_data_socket() as socket:
            # Receive outside the try: without a client identity there is no
            # one to reply to, so a socket failure must propagate.
            identity, message = socket.recv_multipart(copy=False)

            response: Union[RPCStartupResponse, BaseException]
            try:
                request: RPCStartupRequest = pickle.loads(message.buffer)

                # Handle the query from the Client.
                if request == RPCStartupRequest.IS_SERVER_READY:
                    tracing_enabled = self.engine.is_tracing_enabled()
                    response = RPCStartupResponse(
                        tracing_enabled=tracing_enabled)
                else:
                    # Previously `response` was left unbound here.
                    response = ValueError(
                        f"Unknown RPCStartupRequest: {request!r}")

            except Exception as e:
                response = e

            socket.send_multipart((identity, pickle.dumps(response)),
                                  copy=False)

    def run_engine_loop(self):
        """Core busy loop of the LLMEngine."""

        while True:
            if not self.engine.has_unfinished_requests():
                # Poll until there is work to do.
                while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
                    # When there's no work, check on engine health and send
                    # health status back to client
                    self._health_check()
                    self.engine.do_log_stats()
                    logger.debug("Waiting for new requests in engine loop.")

            # Handle any input from the client.
            self.handle_new_input()

            # Engine step.
            request_outputs = self.engine_step()

            # Send request outputs (if async, done in engine_step callback).
            if not self.use_async_sockets:
                self._send_outputs(request_outputs)

    def engine_step(self) -> List[RequestOutput]:
        """Engine step wrapper with error handling.

        Returns:
            The step's RequestOutputs, or [] when a single request failed
            input processing (that request's error is sent to the client).

        Raises:
            SystemExit: re-raised untouched.
            BaseException: any other failure; the engine is marked errored
                and the client is notified before re-raising.
        """
        try:
            return self.engine.step()
        except SystemExit:
            raise
        except InputProcessingError as e:
            # Special case where we handle an error preparing the inputs for
            # a single request in the batch
            rpc_err = RPCError(request_id=e.request_id,
                               is_engine_errored=False,
                               exception=e.__cause__)
            self._send_outputs(rpc_err)
            return []
        except BaseException as e:
            self._set_errored(e)
            rpc_err = RPCError(request_id=None,
                               is_engine_errored=True,
                               exception=e)
            self._send_outputs(rpc_err)
            raise

    def handle_new_input(self):
        """Handle new input from the socket.

        Drains every message currently waiting on the input socket without
        blocking and dispatches each by RPC request type. Any failure marks
        the engine errored, reports UNHEALTHY, and re-raises.
        """
        try:
            while self.input_socket.poll(timeout=0) != 0:
                frames = self.input_socket.recv_multipart(copy=False)
                request = pickle.loads(frames[0].buffer)

                if isinstance(request, RPCProcessRequest):
                    if len(frames) > 1:
                        # Use cloudpickle for logits processors
                        assert isinstance(request.params, SamplingParams)
                        lprocs = cloudpickle.loads(frames[1].buffer)
                        request.params.logits_processors = lprocs
                    self._handle_process_request(request)
                elif isinstance(request, RPCAbortRequest):
                    self._handle_abort_request(request)
                elif isinstance(request, RPCUProfileRequest):
                    if request == RPCUProfileRequest.START_PROFILE:
                        self.start_profile()
                    else:
                        self.stop_profile()
                elif isinstance(request, RPCLoadAdapterRequest):
                    self._handle_load_adapter_request(request)
                elif isinstance(request, RPCResetMultiModalCacheRequest):
                    self.reset_mm_cache()
                elif isinstance(request, RPCResetPrefixCacheRequest):
                    self.reset_prefix_cache()
                elif isinstance(request, RPCSleepRequest):
                    self.sleep(request.value)
                elif isinstance(request, RPCWakeUpRequest):
                    self.wake_up(request.tags)
                elif isinstance(request, RPCIsSleepingRequest):
                    self._handle_is_sleeping_request(request)
                else:
                    raise ValueError("Unknown RPCRequest Type: "
                                     f"{type(request)}")

        except Exception as e:
            self._set_errored(e)
            self._send_unhealthy(e)
            raise e from None

    def _handle_process_request(self, request: RPCProcessRequest):
        """Handle RPCProcessRequest by adding it to the LLMEngine.

        If the engine has already errored, the request is rejected with an
        ENGINE_DEAD_ERROR and is not added.
        """
        request_id = request.request_id

        if self._errored_with is not None:
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=True,
                               exception=ENGINE_DEAD_ERROR(self._errored_with))
            self._send_outputs(rpc_err)
            # The client has already been told this request failed; do not
            # also hand it to the dead engine.
            return

        try:
            self.engine.add_request(request_id=request_id,
                                    prompt=request.prompt,
                                    params=request.params,
                                    lora_request=request.lora_request,
                                    trace_headers=request.trace_headers,
                                    priority=request.priority)

            if self.log_requests:
                logger.info("Added request %s.", request.request_id)

        except Exception as e:
            # We do not call self._set_errored() here, since the error
            # is due to an issue adding this request to the engine,
            # rather than an issue with the engine itself.
            logger.debug("Failed to add request %s to engine. %s",
                         request.request_id, e)
            is_errored = self._errored_with is not None
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=is_errored,
                               exception=e)
            self._send_outputs(rpc_err)

            # Remove request from the engine.
            self.engine.abort_request(request_id)

    def _handle_abort_request(self, request: RPCAbortRequest):
        """Abort the request in the engine, logging if enabled."""
        self.engine.abort_request(request.request_id)
        if self.log_requests:
            logger.info("Aborted request %s.", request.request_id)

    def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
        """Load a LoRA adapter and reply with success or the error."""
        try:
            self.engine.add_lora(request.lora_request)
        except BaseException as e:
            # Send back an error if the adapter fails to load
            rpc_err = RPCError(request_id=request.request_id,
                               is_engine_errored=False,
                               exception=e)
            self._send_outputs(rpc_err)
            return
        # Otherwise, send back the successful load message
        self._send_outputs(
            RPCAdapterLoadedResponse(request_id=request.request_id))

    def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest):
        """Reply with whether the engine is currently sleeping."""
        is_sleeping = self.is_sleeping()
        self._send_outputs(
            RPCIsSleepingResponse(request_id=request.request_id,
                                  is_sleeping=is_sleeping))

    def _health_check(self):
        """Send a heartbeat reflecting the engine's health."""
        # Send unhealthy if engine has already errored. Return immediately so
        # a passing check_health() cannot follow up with a HEALTHY message.
        if self._errored_with is not None:
            self._send_unhealthy(self._errored_with)
            return

        try:
            self.engine.check_health()
            self._send_healthy()
        except Exception as e:
            self._set_errored(e)
            self._send_unhealthy(e)

    def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
        """Send outputs back to the engine client. These can be:
        - Exceptions
        - A list of generation outputs
        - A response from loading a lora adapter

        Falsy outputs (e.g. an empty list) are not sent.
        """
        if outputs:
            try:
                from ray.exceptions import RayTaskError
            except ImportError:
                pass
            else:
                # RayTaskError might not be picklable here. We need to unpack
                # the underlying exception as the real exception in the output.
                if (isinstance(outputs, RPCError)
                        and isinstance(outputs.exception, RayTaskError)):
                    outputs.exception = outputs.exception.cause

            output_bytes = pickle.dumps(outputs)
            self.output_socket.send_multipart((output_bytes, ), copy=False)

    def _send_healthy(self):
        """Send HEALTHY message to RPCClient."""
        if not self.heartbeat_socket.closed:
            self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)

    def _send_unhealthy(self, error: BaseException):
        """Send UNHEALTHY message to RPCClient."""
        if not self.heartbeat_socket.closed:
            error_bytes = pickle.dumps(error)
            self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)

    def _async_socket_engine_callback(self,
                                      request_outputs: REQUEST_OUTPUTS_T):
        """Callback used by engine to make socket handling async with GPU."""
        self._send_outputs(request_outputs)
        self.handle_new_input()

    def _set_errored(self, e: BaseException):
        """Record e as the error state if this is the first issue."""
        if self._errored_with is None:
            self._errored_with = e

    def start_profile(self) -> None:
        self.engine.start_profile()

    def stop_profile(self) -> None:
        self.engine.stop_profile()

    def reset_mm_cache(self) -> bool:
        return self.engine.reset_mm_cache()

    def reset_prefix_cache(self) -> bool:
        return self.engine.reset_prefix_cache()

    def sleep(self, level: int = 1) -> None:
        self.engine.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        self.engine.wake_up(tags)

    def is_sleeping(self) -> bool:
        return self.engine.is_sleeping()

443

444
445
446
447
def signal_handler(*_) -> None:
    """Signal handler that converts the signal into a KeyboardInterrupt.

    Installed for SIGTERM by run_mp_engine so that MQLLMEngine.start()
    takes its KeyboardInterrupt shutdown path and runs cleanup.
    """
    reason = "MQLLMEngine terminated"
    raise KeyboardInterrupt(reason)


448
449
def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
                  ipc_path: str, disable_log_stats: bool,
                  enable_log_requests: bool, engine_alive):
    """Entry point: build an MQLLMEngine in this process and run it.

    Args:
        vllm_config: Engine configuration.
        usage_context: Usage context forwarded to the engine.
        ipc_path: Base path for the engine's zeromq IPC sockets.
        disable_log_stats: Disable periodic stats logging.
        enable_log_requests: Log added/aborted requests.
        engine_alive: Object with a writable ``value`` attribute (presumably
            a shared multiprocessing value — confirm with caller); set to
            False if building or running the engine raises.
    """
    try:
        # Ensure we can serialize transformer config before spawning
        maybe_register_config_serialize_by_value()

        mq_engine = MQLLMEngine.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            disable_log_stats=disable_log_stats,
            enable_log_requests=enable_log_requests,
            ipc_path=ipc_path,
        )

        # Route SIGTERM through KeyboardInterrupt so start() shuts down via
        # its normal cleanup path.
        signal.signal(signal.SIGTERM, signal_handler)

        mq_engine.start()
    except BaseException as exc:
        logger.exception(exc)
        engine_alive.value = False
        raise exc from None