engine.py 15.4 KB
Newer Older
1
2
import pickle
import signal
3
4
import threading
import time
5
6
7
8
9
10
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union

import cloudpickle
import zmq

11
from vllm import AsyncEngineArgs, SamplingParams
12
13
14
15
16
17
18
19
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                         IPC_HEALTH_EXT, IPC_INPUT_EXT,
                                         IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
20
                                         RPCError, RPCProcessRequest,
21
22
                                         RPCStartupRequest, RPCStartupResponse,
                                         RPCUProfileRequest)
23
# yapf: enable
24
from vllm.envs import VLLM_RPC_TIMEOUT, VLLM_USE_V1
25
from vllm.executor.gpu_executor import GPUExecutor
26
27
28
29
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext

30
31
32
33
34
if VLLM_USE_V1:
    from vllm.v1.engine.llm_engine import LLMEngine
else:
    from vllm.engine.llm_engine import LLMEngine

35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Union of the individual config classes imported from vllm.config above.
CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
                    SchedulerConfig, LoRAConfig]

# Module-level logger shared by everything in this file.
logger = init_logger(__name__)

# Timeout, in milliseconds, of each blocking poll on the input socket while
# the engine loop is idle (see MQLLMEngine.run_engine_loop).
POLLING_TIMEOUT_MS = 10000
# Single-frame message, pickled once at import time, that
# MQLLMEngine._send_healthy pushes over the heartbeat socket.
HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), )


class MQLLMEngine:
    """A multiprocessing wrapper for :class:`LLMEngine`.

    This class is used to wrap the :class:`LLMEngine` class to enable use
    in a concurrent manner. It runs a background loop and uses zeromq to
    receive new requests and stream outputs incrementally via ipc.

    The :class:`LLMEngine` generate or encode process is kicked off when a new
    RPCProcessRequest is received by the input_socket.

    The self.run_engine_loop checks the input_socket for new requests,
    adds them to the LLMEngine if there are any, calls the internal
    :class:`LLMEngine.step()`, and sends the RequestOutputs back over
    the output_socket.

    If use_async_sockets is set, the logic associated with reading new
    requests from the socket and sending data to the socket is passed
    as a callback to the llm_engine, which calls the logic asynchronously
    such that the IPC can be overlapped with the GPU.

    Args:
        ipc_path: Base path for zeromq interprocess messaging
        use_async_sockets: Whether to make send/recv async with GPU
        log_requests: Whether to log the requests.
        *args: Arguments for :class:`LLMEngine`.
        **kwargs: Arguments for :class:`LLMEngine`.
    """

    def __init__(self,
                 ipc_path: str,
                 use_async_sockets: bool,
                 *args,
                 log_requests: bool = True,
                 **kwargs) -> None:
        """Build the wrapped LLMEngine and bind the zeromq sockets.

        Args:
            ipc_path: Base path; the input, output, heartbeat and data
                endpoints are derived by appending the IPC_*_EXT suffixes.
            use_async_sockets: When true, socket handling is registered as
                the engine's ``process_request_outputs_callback``.
            log_requests: Whether added / aborted requests are logged.
            *args: Positional arguments forwarded to :class:`LLMEngine`.
            **kwargs: Keyword arguments forwarded to :class:`LLMEngine`
                (``use_cached_outputs`` is always forced to ``True``).
        """
        # Cached outputs are safe for MQLLMEngine: every new request output
        # is pickled and sent over the socket right away, which frees the
        # python object to be reused again.
        kwargs.update(use_cached_outputs=True)
        self.engine = LLMEngine(*args, **kwargs)

        self.log_requests = log_requests
        self.use_async_sockets = use_async_sockets
        if use_async_sockets:
            self.engine.process_request_outputs_callback = (
                self._async_socket_engine_callback)

        # Endpoint addresses, all derived from the shared base path.
        input_address = f"{ipc_path}{IPC_INPUT_EXT}"
        output_address = f"{ipc_path}{IPC_OUTPUT_EXT}"
        heartbeat_address = f"{ipc_path}{IPC_HEALTH_EXT}"

        self.ctx = zmq.Context()  # type: ignore[attr-defined]

        # PULL socket: requests coming in from the client.
        self.input_socket = self.ctx.socket(zmq.constants.PULL)
        self.input_socket.bind(input_address)

        # PUSH socket: output stream going back to the client.
        self.output_socket = self.ctx.socket(zmq.constants.PUSH)
        self.output_socket.bind(output_address)

        # PUSH socket: heartbeats going back to the client.
        self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
        self.heartbeat_socket.bind(heartbeat_address)

        # The data socket is only created on demand (make_data_socket),
        # so just remember where it has to be bound.
        self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

        # First error seen by the engine, if any.
        self._errored_with: Optional[BaseException] = None

        # Heartbeat thread and the event used to ask it to stop.
        self._heartbeat_stop_event = threading.Event()
        self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop,
                                                 daemon=True)
        # Heartbeats must be sent more often than the client is willing to
        # wait. VLLM_RPC_TIMEOUT is in ms; the interval is in seconds.
        self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0

        self._last_alive_time = time.time()
        # The engine may legitimately chug away at a generation request for
        # a long time, so the liveness threshold is generous.
        # VLLM_RPC_TIMEOUT is in ms; the threshold is in seconds.
        self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0

125
126
127
128
129
130
131
132
133
134
135
    @property
    def dead_error(self) -> BaseException:
        """Return an ENGINE_DEAD_ERROR, wrapping the recorded error if any."""
        cause = self._errored_with
        if cause is None:
            return ENGINE_DEAD_ERROR()
        return ENGINE_DEAD_ERROR(cause)

    @classmethod
    def from_engine_args(cls, engine_args: AsyncEngineArgs,
                         usage_context: UsageContext, ipc_path: str):
        """Creates an MQLLMEngine from the engine arguments.

        Args:
            engine_args: Source of the engine config and of the
                ``disable_log_requests`` / ``disable_log_stats`` flags.
            usage_context: Forwarded unchanged to the constructor.
            ipc_path: Base path for zeromq interprocess messaging.

        Returns:
            A new instance of ``cls``.
        """
        # Plugins have to be set up in every process.
        from vllm.plugins import load_general_plugins
        load_general_plugins()

        engine_config = engine_args.create_engine_config()
        executor_class = LLMEngine._get_executor_cls(engine_config)

        # Async socket handling needs async output processing in the model
        # config and is not used together with V1.
        async_output_proc = engine_config.model_config.use_async_output_proc
        use_async_sockets = async_output_proc and not VLLM_USE_V1

        log_requests = not engine_args.disable_log_requests
        log_stats = not engine_args.disable_log_stats

        return cls(ipc_path=ipc_path,
                   use_async_sockets=use_async_sockets,
                   **engine_config.to_dict(),
                   executor_class=executor_class,
                   log_requests=log_requests,
                   log_stats=log_stats,
                   usage_context=usage_context)
154
155
156
157
158
159

    def start(self):
        """Run the startup loop, heartbeat thread and engine loop.

        Any ``Exception`` raised by those stages is logged and swallowed.
        A ``KeyboardInterrupt`` ends the engine quietly. ``cleanup()`` is
        called in every case.
        """
        try:
            # Inner try: ordinary failures are logged, not propagated.
            try:
                logger.debug("Starting Startup Loop.")
                self.run_startup_loop()

                logger.debug("Starting heartbeat thread")
                self.heartbeat_thread.start()

                logger.debug("Starting Engine Loop.")
                self.run_engine_loop()
            except Exception as exc:
                logger.exception(repr(exc))
        except KeyboardInterrupt:
            # Outer try: an interrupt is treated as a shutdown request.
            logger.debug("Shutting down MQLLMEngine.")
        finally:
            logger.debug("MQLLMEngine is shut down.")
            self.cleanup()

    def cleanup(self):
        """Cleanup zeromq state on shutdown."""
        # Tell the heartbeat loop to exit at its next wake-up.
        self._heartbeat_stop_event.set()

        # Closes all sockets and destroys context; linger=0 is passed so
        # that pending messages are not waited for.
        self.ctx.destroy(linger=0)

        # Drop this object's reference to the wrapped engine.
        del self.engine

    @contextmanager
    def make_data_socket(
            self) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
        """Yield a ROUTER socket bound to the data IPC path.

        The socket is always closed (with ``linger=0``) when the context
        exits, including when binding fails.
        """
        data_socket = self.ctx.socket(zmq.constants.ROUTER)
        try:
            data_socket.bind(self.data_ipc_path)
            yield data_socket
        finally:
            data_socket.close(linger=0)

    def run_startup_loop(self) -> None:
        """Startup loop for sending data from Engine -> Client.

        Receives one ``(identity, message)`` request on the data socket and
        answers it. The reply is either an ``RPCStartupResponse`` or, when
        the request cannot be handled, the exception describing why.

        Raises:
            Exception: Whatever ``recv_multipart`` raises. Without the
                identity frame there is nobody to reply to, so the receive
                error is propagated as-is.
        """

        with self.make_data_socket() as socket:
            # Receive outside of the try block: if this fails there is no
            # identity to address a reply to, and the original error is
            # more useful than an UnboundLocalError further down.
            identity, message = socket.recv_multipart(copy=False)

            response: Union[RPCStartupResponse, BaseException]
            try:
                # NOTE: pickle.loads trusts the peer on the IPC socket.
                request: RPCStartupRequest = pickle.loads(message.buffer)

                # Handle the query from the Client.
                if request == RPCStartupRequest.IS_SERVER_READY:
                    tracing_enabled = self.engine.is_tracing_enabled()
                    response = RPCStartupResponse(
                        tracing_enabled=tracing_enabled)
                else:
                    # Always produce a reply so the client is not left
                    # waiting and `response` is never unbound.
                    response = ValueError(
                        f"Unknown RPCStartupRequest: {request!r}")

            except Exception as e:
                response = e

            socket.send_multipart((identity, pickle.dumps(response)),
                                  copy=False)

    def run_engine_loop(self):
        """Core busy loop of the LLMEngine.

        Each iteration marks the loop as alive, waits for input if the
        engine has nothing to do, ingests new requests, runs one engine
        step and (unless async sockets are in use) sends the outputs.
        """

        while True:
            self._alive()

            if not self.engine.has_unfinished_requests():
                # Nothing to do: block on the input socket until a request
                # shows up, refreshing liveness and stats on every timeout.
                while True:
                    if self.input_socket.poll(
                            timeout=POLLING_TIMEOUT_MS) != 0:
                        break
                    self._alive()
                    self.engine.do_log_stats()
                    logger.debug("Waiting for new requests in engine loop.")

            # Pull in whatever the client has sent.
            self.handle_new_input()

            # Run one step of the engine.
            step_outputs = self.engine_step()

            # With async sockets the engine_step callback already sent them.
            if self.use_async_sockets:
                continue
            self._send_outputs(step_outputs)

    def engine_step(self) -> List[RequestOutput]:
        """Engine step wrapper with error handling.

        Returns:
            The result of ``self.engine.step()``.

        Raises:
            SystemExit: Passed through untouched.
            BaseException: Any other error is recorded, reported to the
                client as an engine-level ``RPCError``, then re-raised.
        """
        try:
            return self.engine.step()
        except SystemExit:
            raise
        except BaseException as exc:
            self._set_errored(exc)
            self._send_outputs(
                RPCError(request_id=None,
                         is_engine_errored=True,
                         exception=exc))
            raise exc

    def handle_new_input(self):
        """Handle new input from the socket.

        Drains every message currently waiting on the input socket and
        dispatches it by request type. Any ``Exception`` is recorded,
        reported as unhealthy over the heartbeat socket, and re-raised.
        """
        try:
            # poll(timeout=0) does not block: stop once nothing is queued.
            while self.input_socket.poll(timeout=0) != 0:
                frames = self.input_socket.recv_multipart(copy=False)
                request = pickle.loads(frames[0].buffer)

                if isinstance(request, RPCProcessRequest):
                    if len(frames) > 1:
                        # A second frame carries the logits processors,
                        # which are serialized with cloudpickle.
                        assert isinstance(request.params, SamplingParams)
                        lprocs = cloudpickle.loads(frames[1].buffer)
                        request.params.logits_processors = lprocs
                    self._handle_process_request(request)
                    continue

                if isinstance(request, RPCAbortRequest):
                    self._handle_abort_request(request)
                    continue

                if not isinstance(request, RPCUProfileRequest):
                    raise ValueError("Unknown RPCRequest Type: "
                                     f"{type(request)}")

                # Profile request: anything other than START means stop.
                if request == RPCUProfileRequest.START_PROFILE:
                    self.start_profile()
                else:
                    self.stop_profile()

        except Exception as exc:
            self._set_errored(exc)
            self._send_unhealthy(exc)
            raise exc

276
277
    def _handle_process_request(self, request: RPCProcessRequest):
        """Handle RPCProcessRequest by adding it to the LLMEngine.

        If the engine has already errored, the client is sent an
        engine-level ``RPCError`` and the request is NOT added. Otherwise
        the request is added; a failure while adding is reported back for
        this request only and the request is aborted in the engine.

        Args:
            request: The process request received from the client.
        """
        request_id = request.request_id

        if self._errored_with is not None:
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=True,
                               exception=ENGINE_DEAD_ERROR(self._errored_with))
            self._send_outputs(rpc_err)
            # The client has just been told this request failed; do not
            # also queue it on an engine that is known to be dead.
            return

        try:
            self.engine.add_request(
                request_id=request_id,
                prompt=request.prompt,
                params=request.params,
                lora_request=request.lora_request,
                trace_headers=request.trace_headers,
                prompt_adapter_request=request.prompt_adapter_request,
                priority=request.priority)

            if self.log_requests:
                logger.info("Added request %s.", request.request_id)

        except Exception as e:
            # We do not call self._set_errored() here, since the error
            # is due to an issue adding this request to the engine,
            # rather than an issue with the engine itself.
            is_errored = self._errored_with is not None
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=is_errored,
                               exception=e)
            self._send_outputs(rpc_err)

            # Remove request from the engine.
            self.engine.abort_request(request_id)

    def _handle_abort_request(self, request: RPCAbortRequest):
        """Abort the request named by ``request`` in the wrapped engine."""
        request_id = request.request_id
        self.engine.abort_request(request_id)
        if not self.log_requests:
            return
        logger.info("Aborted request %s.", request_id)

317
318
319
320
321
322
323
324
325
326
    def _heartbeat_loop(self):
        """Send a heartbeat every interval until the stop event is set."""
        while True:
            # wait() returns True once the stop event is set, and False
            # when the interval elapsed without it being set.
            stop_requested = self._heartbeat_stop_event.wait(
                timeout=self.heartbeat_interval_seconds)
            if stop_requested:
                break
            self._heartbeat()

        logger.debug("Exiting MQLLMEngine heartbeat thread")

    def _heartbeat(self):
        """Send one heartbeat reflecting the current state of the engine.

        Unhealthy is sent if an error is already recorded, if the main loop
        has not reported in within ``last_alive_threshold`` seconds, or if
        the engine health check raises. Otherwise healthy is sent.
        """
        # An earlier failure is simply reported again.
        if self._errored_with is not None:
            self._send_unhealthy(self._errored_with)
            return

        # Check for life of the main loop.
        silent_for = time.time() - self._last_alive_time
        if silent_for > self.last_alive_threshold:
            self._send_unhealthy(RuntimeError("Engine loop has died"))
            return

        # Otherwise check the health of the engine itself;
        # self.engine.check_health() raises on unhealthy.
        try:
            self.engine.check_health()
            self._send_healthy()
        except Exception as exc:
            self._set_errored(exc)
            self._send_unhealthy(exc)
343
344
345
346
347
348
349
350
351

    def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
        """Send List of RequestOutput to RPCClient."""
        if outputs:
            output_bytes = pickle.dumps(outputs)
            self.output_socket.send_multipart((output_bytes, ), copy=False)

    def _send_healthy(self):
        """Send HEALTHY message to RPCClient."""
        # Nothing to do once the heartbeat socket has been closed.
        if self.heartbeat_socket.closed:
            return
        self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
354
355
356

    def _send_unhealthy(self, error: BaseException):
        """Send UNHEALTHY message to RPCClient.

        The message is the pickled ``error``, sent over the heartbeat
        socket. Nothing is sent once that socket has been closed.
        """
        if self.heartbeat_socket.closed:
            return
        payload = pickle.dumps(error)
        self.heartbeat_socket.send_multipart((payload, ), copy=False)
360
361
362
363
364
365
366
367
368
369
370
371

    def _async_socket_engine_callback(self,
                                      request_outputs: REQUEST_OUTPUTS_T):
        """Callback used by engine to make socket handling async with GPU.

        Registered as ``engine.process_request_outputs_callback`` in
        ``__init__`` when ``use_async_sockets`` is set.

        Args:
            request_outputs: Outputs handed over by the engine.
        """
        # Send the finished outputs first, then pick up new requests.
        self._send_outputs(request_outputs)
        self.handle_new_input()

    def _set_errored(self, e: BaseException):
        """Log and set errored status if this is the first issue."""
        if self._errored_with is None:
            self._errored_with = e

372
373
374
    def _alive(self):
        self._last_alive_time = time.time()

375
376
377
378
379
380
381
382
383
384
385
386
    def start_profile(self) -> None:
        """Start profiling on the engine's model executor."""
        executor = self.engine.model_executor
        # Exact type check on purpose: only a plain GPUExecutor is called
        # directly, every other executor goes through _run_workers.
        if type(executor) is GPUExecutor:
            executor.start_profile()
            return
        executor._run_workers("start_profile")

    def stop_profile(self) -> None:
        """Stop profiling on the engine's model executor."""
        executor = self.engine.model_executor
        # Exact type check on purpose: only a plain GPUExecutor is called
        # directly, every other executor goes through _run_workers.
        if type(executor) is GPUExecutor:
            executor.stop_profile()
            return
        executor._run_workers("stop_profile")

387
388
389
390
391
392
393
394
395
396
397
398
399
400

def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
                  ipc_path: str):
    """Build an MQLLMEngine and run it, treating SIGTERM as an interrupt.

    Args:
        engine_args: Arguments the engine is created from.
        usage_context: Forwarded to ``MQLLMEngine.from_engine_args``.
        ipc_path: Base path for zeromq interprocess messaging.
    """

    def raise_interrupt(*_) -> None:
        # Interrupt server on sigterm
        raise KeyboardInterrupt("MQLLMEngine terminated")

    # Install the handler before the engine is built.
    signal.signal(signal.SIGTERM, raise_interrupt)

    MQLLMEngine.from_engine_args(engine_args=engine_args,
                                 usage_context=usage_context,
                                 ipc_path=ipc_path).start()