engine.py 17.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
9
10
import pickle
import signal
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union

import cloudpickle
import zmq

11
from vllm import AsyncEngineArgs, SamplingParams
12
from vllm.config import VllmConfig
Joe Runde's avatar
Joe Runde committed
13
from vllm.engine.llm_engine import LLMEngine
14
15
16
17
18
19
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                         IPC_HEALTH_EXT, IPC_INPUT_EXT,
                                         IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
20
                                         RPCAdapterLoadedResponse, RPCError,
21
22
                                         RPCIsSleepingRequest,
                                         RPCIsSleepingResponse,
23
                                         RPCLoadAdapterRequest,
24
                                         RPCProcessRequest,
25
                                         RPCResetMultiModalCacheRequest,
26
                                         RPCResetPrefixCacheRequest,
27
28
29
                                         RPCSleepRequest, RPCStartupRequest,
                                         RPCStartupResponse,
                                         RPCUProfileRequest, RPCWakeUpRequest)
30
31
32
# yapf: enable
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
33
34
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
35
from vllm.usage.usage_lib import UsageContext
36
from vllm.worker.model_runner_base import InputProcessingError
37
38
39
40
41
42
43
44

logger = init_logger(__name__)

# Milliseconds the idle engine loop blocks on the input socket before it
# runs a health check and logs stats, then polls again.
POLLING_TIMEOUT_MS = 10_000

# Heartbeat payload pickled once at import time and reused for every
# healthy heartbeat; it is a one-frame tuple for send_multipart.
HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR),)


class MQLLMEngine:
    """A multiprocessing wrapper for {class}`LLMEngine`.

    This class is used to wrap the {class}`LLMEngine` class to enable use
    in a concurrent manner. It runs a background loop and uses zeromq to
    receive new requests and stream outputs incrementally via ipc.

    The {class}`LLMEngine` generate or encode process is kicked off when a new
    RPCProcessRequest is received by the input_socket.

    The run_engine_loop method checks the input_socket for new requests,
    adds them to the LLMEngine if there are any, calls the internal
    {class}`LLMEngine.step()`, and sends the RequestOutputs back over
    the output_socket.

    If use_async_sockets is set, the logic associated with reading new
    requests from the socket and sending data to the socket is passed
    as a callback to the llm_engine, which calls the logic asynchronously
    such that the IPC can be overlapped with the GPU.

    Once an error has been recorded via ``_set_errored`` the engine stays
    errored: new process requests are rejected with the engine-dead error
    and idle health checks keep reporting unhealthy.

    Args:
        ipc_path: Base path for zeromq interprocess messaging
        use_async_sockets: Whether to make send/recv async with GPU
        log_requests: Whether to log the requests.
        *args: Arguments for {class}`LLMEngine`.
        **kwargs: Arguments for {class}`LLMEngine`.
    """

    def __init__(self,
                 ipc_path: str,
                 use_async_sockets: bool,
                 *args,
                 log_requests: bool = True,
                 **kwargs) -> None:
        # For MQLLMEngine, we can use cached outputs, since each new request
        # output is immediately pickled and sent over the socket, which frees
        # the python object to be reused again.
        kwargs['use_cached_outputs'] = True

        self.engine = LLMEngine(*args, **kwargs)
        self.log_requests = log_requests

        self.use_async_sockets = use_async_sockets
        if self.use_async_sockets:
            self.engine.process_request_outputs_callback = \
                self._async_socket_engine_callback

        self.ctx = zmq.Context()  # type: ignore[attr-defined]

        # Receive input from the client.
        self.input_socket = self.ctx.socket(zmq.constants.PULL)
        self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}")

        # Send output stream back to client.
        self.output_socket = self.ctx.socket(zmq.constants.PUSH)
        self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")

        # Send heartbeats back to client.
        self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
        self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")

        # IPC path for the data socket (bound by make_data_socket for the
        # startup handshake in run_startup_loop).
        self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

        # Error state: the first recorded error, or None while healthy.
        self._errored_with: Optional[BaseException] = None

    @property
    def dead_error(self) -> BaseException:
        """ENGINE_DEAD_ERROR, wrapping the recorded error if there is one."""
        if self._errored_with is not None:
            return ENGINE_DEAD_ERROR(self._errored_with)
        else:
            return ENGINE_DEAD_ERROR()

    @classmethod
    def from_vllm_config(cls, vllm_config: VllmConfig,
                         usage_context: UsageContext,
                         disable_log_requests: bool, disable_log_stats: bool,
                         ipc_path: str) -> "MQLLMEngine":
        """Build an engine from a VllmConfig after loading general plugins.

        Async sockets are enabled when the model config requests async
        output processing.
        """
        # Setup plugins for each process
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        use_async_sockets = vllm_config.model_config.use_async_output_proc

        return cls(
            vllm_config=vllm_config,
            executor_class=LLMEngine._get_executor_cls(vllm_config),
            ipc_path=ipc_path,
            usage_context=usage_context,
            use_async_sockets=use_async_sockets,
            log_requests=(not disable_log_requests),
            log_stats=(not disable_log_stats),
        )

    @staticmethod
    def from_engine_args(engine_args: AsyncEngineArgs,
                         usage_context: UsageContext, ipc_path: str):
        """Creates an MQLLMEngine from the engine arguments."""

        vllm_config = engine_args.create_engine_config(usage_context)
        return MQLLMEngine.from_vllm_config(
            ipc_path=ipc_path,
            vllm_config=vllm_config,
            usage_context=usage_context,
            disable_log_requests=engine_args.disable_log_requests,
            disable_log_stats=engine_args.disable_log_stats,
        )

    def start(self):
        """Run the startup handshake, then the engine loop, until exit.

        Any Exception from either loop is logged and not re-raised; a
        KeyboardInterrupt ends the loop quietly. ``cleanup`` always runs.
        """
        try:
            try:
                logger.debug("Starting Startup Loop.")
                self.run_startup_loop()
                logger.debug("Starting Engine Loop.")
                self.run_engine_loop()
            except Exception as e:
                logger.exception(repr(e))
        except KeyboardInterrupt:
            logger.debug("Shutting down MQLLMEngine.")
        finally:
            logger.debug("MQLLMEngine is shut down.")
            self.cleanup()

    def cleanup(self):
        """Cleanup zeromq state on shutdown."""
        # Closes all sockets and destroys context.
        self.ctx.destroy(linger=0)
        del self.engine

    @contextmanager
    def make_data_socket(
            self) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
        """Yield a ROUTER socket bound to the data IPC path; always closed."""
        socket = self.ctx.socket(zmq.constants.ROUTER)
        try:
            socket.bind(self.data_ipc_path)
            yield socket
        finally:
            socket.close(linger=0)

    def run_startup_loop(self) -> None:
        """Startup loop for sending data from Engine -> Client.

        Receives one request on the data socket and replies to the same
        peer identity with an RPCStartupResponse, or with the exception
        raised while handling the request.
        """

        with self.make_data_socket() as socket:
            # Receive outside the try: if this fails there is no peer
            # identity to reply to, so let the error propagate to start()
            # instead of failing later on an unbound `identity`.
            identity, message = socket.recv_multipart(copy=False)

            response: Union[RPCStartupResponse, BaseException]
            try:
                request: RPCStartupRequest = pickle.loads(message.buffer)

                # Handle the query from the Client.
                if request == RPCStartupRequest.IS_SERVER_READY:
                    tracing_enabled = self.engine.is_tracing_enabled()
                    response = RPCStartupResponse(
                        tracing_enabled=tracing_enabled)
                else:
                    # Without this branch `response` would be unbound below.
                    raise ValueError(
                        f"Unknown RPCStartupRequest: {request!r}")

            except Exception as e:
                response = e

            socket.send_multipart((identity, pickle.dumps(response)),
                                  copy=False)

    def run_engine_loop(self):
        """Core busy loop of the LLMEngine."""

        while True:
            if not self.engine.has_unfinished_requests():
                # Poll until there is work to do.
                while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
                    # When there's no work, check on engine health and send
                    # health status back to client
                    self._health_check()
                    self.engine.do_log_stats()
                    logger.debug("Waiting for new requests in engine loop.")

            # Handle any input from the client.
            self.handle_new_input()

            # Engine step.
            request_outputs = self.engine_step()

            # Send request outputs (if async, done in engine_step callback).
            if not self.use_async_sockets:
                self._send_outputs(request_outputs)

    def engine_step(self) -> List[RequestOutput]:
        """Engine step wrapper with error handling.

        An InputProcessingError affects a single request: it is reported
        for that request only and an empty list is returned. Any other
        error (except SystemExit, which passes through) marks the engine
        errored, is sent to the client, and is re-raised.
        """
        try:
            return self.engine.step()
        except SystemExit:
            raise
        except InputProcessingError as e:
            # Special case where we handle an error preparing the inputs for
            # a single request in the batch
            rpc_err = RPCError(request_id=e.request_id,
                               is_engine_errored=False,
                               exception=e.__cause__)
            self._send_outputs(rpc_err)
            return []
        except BaseException as e:
            self._set_errored(e)
            rpc_err = RPCError(request_id=None,
                               is_engine_errored=True,
                               exception=e)
            self._send_outputs(rpc_err)
            raise e

    def handle_new_input(self):
        """Handle new input from the socket.

        Drains every message currently queued (non-blocking poll) and
        dispatches it by request type. Any failure marks the engine errored,
        is reported on the heartbeat socket, and is re-raised.
        """
        try:
            while self.input_socket.poll(timeout=0) != 0:
                frames = self.input_socket.recv_multipart(copy=False)
                request = pickle.loads(frames[0].buffer)

                if isinstance(request, RPCProcessRequest):
                    if len(frames) > 1:
                        # Use cloudpickle for logits processors
                        assert isinstance(request.params, SamplingParams)
                        lprocs = cloudpickle.loads(frames[1].buffer)
                        request.params.logits_processors = lprocs
                    self._handle_process_request(request)
                elif isinstance(request, RPCAbortRequest):
                    self._handle_abort_request(request)
                elif isinstance(request, RPCUProfileRequest):
                    if request == RPCUProfileRequest.START_PROFILE:
                        self.start_profile()
                    else:
                        self.stop_profile()
                elif isinstance(request, RPCLoadAdapterRequest):
                    self._handle_load_adapter_request(request)
                elif isinstance(request, RPCResetMultiModalCacheRequest):
                    self.reset_mm_cache()
                elif isinstance(request, RPCResetPrefixCacheRequest):
                    self.reset_prefix_cache()
                elif isinstance(request, RPCSleepRequest):
                    self.sleep(request.value)
                elif isinstance(request, RPCWakeUpRequest):
                    self.wake_up(request.tags)
                elif isinstance(request, RPCIsSleepingRequest):
                    self._handle_is_sleeping_request(request)
                else:
                    raise ValueError("Unknown RPCRequest Type: "
                                     f"{type(request)}")

        except Exception as e:
            self._set_errored(e)
            self._send_unhealthy(e)
            raise e from None

    def _handle_process_request(self, request: RPCProcessRequest):
        """Handle RPCProcessRequest by adding it to the LLMEngine.

        If the engine has already errored, the request is rejected with the
        engine-dead error and is NOT added to the engine.
        """
        request_id = request.request_id

        if self._errored_with is not None:
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=True,
                               exception=self.dead_error)
            self._send_outputs(rpc_err)
            # The client has already been told this request failed; adding
            # it anyway would produce a second error or orphaned outputs.
            return

        try:
            self.engine.add_request(
                request_id=request_id,
                prompt=request.prompt,
                params=request.params,
                lora_request=request.lora_request,
                trace_headers=request.trace_headers,
                prompt_adapter_request=request.prompt_adapter_request,
                priority=request.priority)

            if self.log_requests:
                logger.info("Added request %s.", request.request_id)

        except Exception as e:
            # We do not set self._errored_with here, since the error
            # is due to an issue adding this request to the engine,
            # rather than an issue with the engine itself.
            logger.debug("Failed to add request %s to engine. %s",
                         request.request_id, e)
            is_errored = self._errored_with is not None
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=is_errored,
                               exception=e)
            self._send_outputs(rpc_err)

            # Remove request from the engine.
            self.engine.abort_request(request_id)

    def _handle_abort_request(self, request: RPCAbortRequest):
        """Abort the request in the engine and optionally log it."""
        self.engine.abort_request(request.request_id)
        if self.log_requests:
            logger.info("Aborted request %s.", request.request_id)

    def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
        """Load a LoRA adapter and reply with success or the error."""
        try:
            self.engine.add_lora(request.lora_request)
        except BaseException as e:
            # Send back an error if the adapter fails to load
            rpc_err = RPCError(request_id=request.request_id,
                               is_engine_errored=False,
                               exception=e)
            self._send_outputs(rpc_err)
            return
        # Otherwise, send back the successful load message
        self._send_outputs(
            RPCAdapterLoadedResponse(request_id=request.request_id))

    def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest):
        """Reply with the engine's current sleeping state."""
        is_sleeping = self.is_sleeping()
        self._send_outputs(
            RPCIsSleepingResponse(request_id=request.request_id,
                                  is_sleeping=is_sleeping))

    def _health_check(self):
        """Send a heartbeat reflecting the engine's health."""
        # Send unhealthy if engine has already errored. The errored state is
        # sticky, so stop here rather than possibly following up with a
        # HEALTHY heartbeat that would contradict it.
        if self._errored_with is not None:
            self._send_unhealthy(self._errored_with)
            return

        try:
            self.engine.check_health()
            self._send_healthy()
        except Exception as e:
            self._set_errored(e)
            self._send_unhealthy(e)

    def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
        """Send outputs back to the engine client. These can be:
        - Exceptions
        - A list of generation outputs
        - A response from loading a lora adapter

        Falsy outputs (e.g. an empty list) are not sent.
        """
        if not outputs:
            return

        # Only errors can carry a RayTaskError, so only attempt the optional
        # ray import for them; this keeps a (possibly failing, uncached)
        # import out of the per-step hot path.
        if isinstance(outputs, RPCError):
            try:
                from ray.exceptions import RayTaskError
            except ImportError:
                pass
            else:
                # RayTaskError might not be picklable here. We need to unpack
                # the underlying exception as the real exception in the
                # output.
                if isinstance(outputs.exception, RayTaskError):
                    outputs.exception = outputs.exception.cause

        output_bytes = pickle.dumps(outputs)
        self.output_socket.send_multipart((output_bytes, ), copy=False)

    def _send_healthy(self):
        """Send HEALTHY message to RPCClient."""
        if not self.heartbeat_socket.closed:
            self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)

    def _send_unhealthy(self, error: BaseException):
        """Send UNHEALTHY message to RPCClient."""
        if not self.heartbeat_socket.closed:
            error_bytes = pickle.dumps(error)
            self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)

    def _async_socket_engine_callback(self,
                                      request_outputs: REQUEST_OUTPUTS_T):
        """Callback used by engine to make socket handling async with GPU."""
        self._send_outputs(request_outputs)
        self.handle_new_input()

    def _set_errored(self, e: BaseException):
        """Record `e` as the engine error, keeping only the first one."""
        if self._errored_with is None:
            self._errored_with = e

    def start_profile(self) -> None:
        self.engine.start_profile()

    def stop_profile(self) -> None:
        self.engine.stop_profile()

    def reset_mm_cache(self) -> bool:
        return self.engine.reset_mm_cache()

    def reset_prefix_cache(self) -> bool:
        return self.engine.reset_prefix_cache()

    def sleep(self, level: int = 1) -> None:
        self.engine.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        self.engine.wake_up(tags)

    def is_sleeping(self) -> bool:
        return self.engine.is_sleeping()
430

431
432
433
434
def signal_handler(*_) -> None:
    """Signal callback that turns a termination signal into an interrupt.

    ``run_mp_engine`` installs this for SIGTERM; raising KeyboardInterrupt
    lets ``MQLLMEngine.start`` exit its loop and run cleanup. The handler's
    positional arguments (signal number, frame) are ignored.
    """
    reason = "MQLLMEngine terminated"
    raise KeyboardInterrupt(reason)


435
436
437
def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
                  ipc_path: str, disable_log_stats: bool,
                  disable_log_requests: bool, engine_alive):
    """Build an MQLLMEngine from `vllm_config` and run it until it exits.

    SIGTERM is routed to ``signal_handler`` only after the engine has been
    constructed. On any failure (including KeyboardInterrupt) the error is
    logged, ``engine_alive.value`` is set to False, and the error is
    re-raised without its implicit context.

    Args:
        engine_alive: shared flag object with a ``value`` attribute
            (presumably a multiprocessing Value — confirm with the caller).
    """
    try:
        # Ensure we can serialize transformer config before spawning
        maybe_register_config_serialize_by_value()

        mq_engine = MQLLMEngine.from_vllm_config(
            ipc_path=ipc_path,
            vllm_config=vllm_config,
            usage_context=usage_context,
            disable_log_requests=disable_log_requests,
            disable_log_stats=disable_log_stats,
        )
        signal.signal(signal.SIGTERM, signal_handler)
        mq_engine.start()
    except BaseException as err:
        logger.exception(err)
        engine_alive.value = False
        raise err from None