client.py 18.7 KB
Newer Older
1
import asyncio
2
from contextlib import contextmanager, suppress
3
from typing import Any, AsyncGenerator, Mapping, Optional
4
from uuid import uuid4
5
6
7
8
9
10
11

import cloudpickle
import zmq
import zmq.asyncio

from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
12
# yapf: disable
13
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
14
15
16
                                         VLLM_RPC_SOCKET_LIMIT_CUTOFF,
                                         VLLM_RPC_SUCCESS_STR,
                                         VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
17
                                         RPCGenerateRequest, RPCUtilityRequest)
18
# yapf: enable
19
from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS
20
from vllm.inputs import PromptInputs
21
from vllm.logger import init_logger
22
23
24
25
26
27
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs

logger = init_logger(__name__)

# Address of the in-process ("inproc") endpoint that the client-side proxy
# binds to. A random UUID is generated once, at import time, so the address
# is unique.
INPROC_PROXY_PATH = "inproc://" + str(uuid4())


class RPCClientClosedError(Exception):
    """Raised when the RPC client is used after it has been closed.

    Closing the client destroys its ZMQ context, which normally happens
    during server shutdown. Some callers (e.g. abort and do_log_stats) may
    still run after that point and try to open a socket, which would
    otherwise surface as a ZMQError with a very large stack trace. Raising
    this dedicated error instead lets those callers suppress it cleanly.
    """


class AsyncEngineRPCClient:
    """
    RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.

    The overall design mirrors the Asynchronous Client Server Pattern
    https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern

    On startup, the RPCClient:
        - makes a DEALER socket (to_rpc_server) bound to rpc_path for the
            ipc link to the RPCServer; ipc uses unix sockets under the hood
            (https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html)
        - makes a ROUTER socket (from_api_server) that binds to a random
            inproc address, which uses memory under the hood
            (https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html)
        - runs a proxy in a background asyncio task between
            from_api_server (ROUTER, inproc) and to_rpc_server (DEALER, ipc)

    Each request handled by the asyncio api_server calls generate():
        - make a DEALER socket that connects to from_api_server via inproc
        - send an RPCGenerateRequest to the inproc socket
        - background proxy forwards the request from inproc -> ipc
        - RPCServer responds to the request one token at a time over ipc
        - background proxy forwards the response from ipc -> inproc

    The connection looks like this:
        DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER

    Message routing is performed via identities that are managed by the
    ROUTER socket. A ROUTER socket tracks every connection it has and tells
    the caller about them by putting the connection identity in front of
    each message it receives. When we send a message via a ROUTER, we first
    send an identity frame.
    See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope
    for more details on connection identities.

    This proxy design enables us to use a single unix socket, which
    improves performance by avoiding syscalls (~5%) and avoids resource
    limits such as ulimit, which defaults to 1024 on ubuntu.

    Note: every socket gets set_hwm(VLLM_RPC_ZMQ_HWM). This is presumably
    0, which ZMQ treats as "no limit". An unlimited HWM is required to
    avoid dropping messages under high load. This is generally not
    advisable. However, we are in control of both sides of the connection,
    a failure on either side is catastrophic to the overall system health,
    and memory profiling suggests limited memory overhead relative to
    asyncio. So we will proceed for now.

    See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks
    for more details on high water marks.
    """

    def __init__(self, rpc_path: str):
        """Create the ZMQ context, the two proxy sockets and the proxy task.

        Must be called while an asyncio event loop is running, because the
        proxy is started with ``asyncio.create_task``.

        Args:
            rpc_path: ZMQ endpoint that the DEALER socket towards the
                RPCServer binds to.

        Raises:
            TypeError: If ZMQ reports a non-integer SOCKET_LIMIT.
            ValueError: If SOCKET_LIMIT is below
                VLLM_RPC_SOCKET_LIMIT_CUTOFF.
        """
        self.context = zmq.asyncio.Context()
        self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS
        self._errored = False

        # Maximum number of sockets that can be opened (typically 65536).
        # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
        socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)

        # Explicit checks instead of `assert` (which `python -O` strips).
        # The context is destroyed before raising so it is not leaked.
        if not isinstance(socket_limit, int):
            self.context.destroy()
            raise TypeError("Expected zmq.constants.SOCKET_LIMIT to be an "
                            f"int, got {type(socket_limit).__name__}")
        if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF:
            self.context.destroy()
            raise ValueError(
                f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
                "the number of concurrent requests vLLM can process. Launch "
                "vLLM with --disable-frontend-multiprocessing and open a "
                "GitHub issue so we can investigate.")

        # We only have 1 ipc connection that uses unix sockets, so it is
        # safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. we will
        # not run into ulimit issues).
        self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)

        # IPC connection to RPC Server (uses unix sockets).
        self.to_rpc_server = self.context.socket(zmq.constants.DEALER)
        self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM)
        self.to_rpc_server.bind(rpc_path)

        # In process proxy to RPC Server (uses memory-based messaging).
        self.from_api_server = self.context.socket(zmq.constants.ROUTER)
        self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM)
        self.from_api_server.bind(INPROC_PROXY_PATH)

        # Asyncio background task for the proxy.
        self.proxy_task = asyncio.create_task(
            self.run_proxy(self.from_api_server, self.to_rpc_server))

        # Every request opens its own inproc socket(s), so there is a hard
        # cap on the number of requests that can run in vLLM with frontend
        # multiprocessing. This value is used by uvicorn's
        # --limit-concurrency to return 503 when the server is overloaded.
        # A request can hold 2 sockets at once: 1 for generate() plus 1 for
        # abort(), do_log_stats() or check_health(); a further 2 slots are
        # held back as headroom.
        self.limit_concurrency = socket_limit // 2 - 2

    async def run_proxy(self, socket_from, socket_to):
        """Background task that forwards messages between two sockets.

        Each message is received as a two-frame ``[identity, payload]``
        multipart message and forwarded with the identity frame unchanged.
        That lets the ROUTER side route replies back to the DEALER that
        sent the request. Loops forever; ``close()`` cancels it.
        """
        poller = zmq.asyncio.Poller()
        poller.register(socket_from, zmq.constants.POLLIN)
        poller.register(socket_to, zmq.constants.POLLIN)
        while True:
            events_lst = await poller.poll()
            events = dict(events_lst)
            if socket_from in events:
                identity, msg = await socket_from.recv_multipart()
                await socket_to.send_multipart([identity, msg])
            if socket_to in events:
                identity, msg = await socket_to.recv_multipart()
                await socket_from.send_multipart([identity, msg])

    async def setup(self):
        """Setup the client before it starts sending server requests.

        Waits for the server, then fetches the configs and builds the
        tokenizer group used by the OpenAI-compatible server.
        """

        # Wait until server is ready.
        await self._wait_for_server_rpc()

        # Get the configs.
        self.model_config = await self._get_model_config_rpc()
        self.decoding_config = await self._get_decoding_config_rpc()
        self.tracing_flag = await self._is_tracing_enabled_rpc()

        # Create the tokenizer group.
        # TODO: refactor OAI server to avoid needing this info.
        self.tokenizer = init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=(await self._get_scheduler_config_rpc()),
            parallel_config=(await self._get_parallel_config_rpc()),
            enable_lora=bool(await self._get_lora_config_rpc()),
        )

    def close(self):
        """Stop the proxy task and destroy the ZeroMQ Context."""
        # Cancel the background proxy first. Otherwise it would keep
        # polling sockets that are closed (and a context that is destroyed)
        # just below.
        self.proxy_task.cancel()

        # Close all sockets associated with this context and
        # then terminate the context.
        self.from_api_server.close()
        self.to_rpc_server.close()
        self.context.destroy()

    @contextmanager
    def to_proxy_socket(self):
        """Yield a fresh DEALER socket connected to the inproc proxy.

        The socket is closed with ``linger=0`` on exit.

        Raises:
            RPCClientClosedError: If ``close()`` was already called.
        """
        # Raise a sensible error if the client was already closed.
        # This can happen if a server shutdown is triggered but some coroutines
        # are still running requests.
        # There should not be a race condition with this check because we don't
        # yield to the event loop between here and opening the socket.
        if self.context.closed:
            raise RPCClientClosedError("The ZMQ client has already shut down")

        # Note that we use DEALER to enable asynchronous communication
        # to enable streaming.
        socket = self.context.socket(zmq.constants.DEALER)
        socket.set_hwm(VLLM_RPC_ZMQ_HWM)
        try:
            socket.connect(INPROC_PROXY_PATH)
            yield socket
        finally:
            socket.close(linger=0)

    async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
                                         expected_type: Any,
                                         error_message: str) -> Any:
        """Send an RPC request that is expecting data back.

        Args:
            request: The utility request to send.
            expected_type: Type the reply must be an instance of.
            error_message: Logged when the server returns an exception and
                used as the ValueError message on a type mismatch.

        Returns:
            The unpickled reply. ``None`` is accepted when
            ``expected_type`` is LoRAConfig.

        Raises:
            RPCClientClosedError: If the client was already closed.
            TimeoutError: If the server does not reply in time.
            Exception: Any exception object sent back by the server.
            ValueError: If the reply is not an ``expected_type``.
        """

        with self.to_proxy_socket() as socket:
            # Ping RPCServer with a request.
            await socket.send_multipart([cloudpickle.dumps(request)])

            # Make sure the server responds
            if await socket.poll(timeout=self._data_timeout) == 0:
                raise TimeoutError("Server didn't reply within "
                                   f"{self._data_timeout} ms")

            # Await the data from the Server.
            data = cloudpickle.loads(await socket.recv())

        if isinstance(data, Exception):
            # Re-raise exceptions returned by the server, logging context
            # first (consistent with _send_one_way_rpc_request).
            logger.error(error_message)
            raise data

        # LoRAConfig can legitimately be None.
        if data is None and expected_type == LoRAConfig:
            return data

        if not isinstance(data, expected_type):
            raise ValueError(error_message)

        return data

    async def _send_one_way_rpc_request(
            self,
            request: RPC_REQUEST_TYPE,
            error_message: str,
            socket: Optional[zmq.asyncio.Socket] = None):
        """Send one-way RPC request to trigger an action.

        Uses ``socket`` if given, otherwise opens a temporary proxy socket.
        The server is expected to reply with VLLM_RPC_SUCCESS_STR.

        Raises:
            RPCClientClosedError: If a new socket is needed but the client
                was already closed.
            TimeoutError: If the server does not reply in time.
            Exception: Any exception object sent back by the server (after
                logging ``error_message``).
            ValueError: For any other non-success reply.
        """

        async def do_rpc_call(socket: zmq.asyncio.Socket,
                              request: RPC_REQUEST_TYPE):

            await socket.send_multipart([cloudpickle.dumps(request)])

            if await socket.poll(timeout=self._data_timeout) == 0:
                raise TimeoutError("Server didn't reply within "
                                   f"{self._data_timeout} ms")

            return cloudpickle.loads(await socket.recv())

        # Make a new socket connection.
        if socket is None:
            with self.to_proxy_socket() as socket:
                response = await do_rpc_call(socket, request)

        # Use existing socket connection.
        else:
            response = await do_rpc_call(socket, request)

        if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
            if isinstance(response, Exception):
                logger.error(error_message)
                raise response
            raise ValueError(error_message)

    async def get_tokenizer(self, lora_request: LoRARequest):
        """Return the tokenizer for ``lora_request`` (requires setup())."""
        return await self.tokenizer.get_lora_tokenizer_async(lora_request)

    async def get_decoding_config(self) -> DecodingConfig:
        """Return the DecodingConfig cached by setup()."""
        return self.decoding_config

    async def get_model_config(self) -> ModelConfig:
        """Return the ModelConfig cached by setup()."""
        return self.model_config

    async def is_tracing_enabled(self) -> bool:
        """Return the tracing flag cached by setup()."""
        return self.tracing_flag

    async def _wait_for_server_rpc(self):
        """Wait for the RPCServer to start up."""

        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.IS_SERVER_READY,
            error_message="Unable to start RPC Server")

    async def _get_model_config_rpc(self) -> ModelConfig:
        """Get the ModelConfig object from the RPC Server"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_MODEL_CONFIG,
            expected_type=ModelConfig,
            error_message="Could not get ModelConfig from RPC Server")

    async def _get_decoding_config_rpc(self) -> DecodingConfig:
        """Get DecodingConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_DECODING_CONFIG,
            expected_type=DecodingConfig,
            error_message="Could not get DecodingConfig from RPC Server")

    async def _get_parallel_config_rpc(self) -> ParallelConfig:
        """Get ParallelConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_PARALLEL_CONFIG,
            expected_type=ParallelConfig,
            error_message="Could not get ParallelConfig from RPC Server")

    async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
        """Get SchedulerConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_SCHEDULER_CONFIG,
            expected_type=SchedulerConfig,
            error_message="Could not get SchedulerConfig from RPC Server")

    async def _get_lora_config_rpc(self) -> LoRAConfig:
        """Get LoRAConfig from the RPCServer (may be None)"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_LORA_CONFIG,
            expected_type=LoRAConfig,
            error_message="Could not get LoRAConfig from RPC Server")

    async def _is_tracing_enabled_rpc(self) -> bool:
        """Get is_tracing_enabled flag from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.IS_TRACING_ENABLED,
            expected_type=bool,
            error_message="Could not get is_tracing_enabled from RPC Server")

    async def abort(self, request_id: str):
        """Send an ABORT_REQUEST signal to the RPC Server"""

        # Suppress timeouts as well.
        # In cases where the server is busy processing requests and a very
        # large volume of abort requests arrive, it is likely that the server
        # will not be able to ack all of them in time. We have seen this when
        # we abort 20k requests at once while another 2k are processing - many
        # of them time out, but we see the server successfully abort all of
        # the requests.
        # In this case we assume that the server has received or will receive
        # these abort requests, and ignore the timeout. This prevents a massive
        # wall of `TimeoutError` stack traces.
        with suppress(RPCClientClosedError, TimeoutError):
            await self._send_one_way_rpc_request(
                request=RPCAbortRequest(request_id),
                error_message=f"RPCAbortRequest {request_id} failed")

    async def do_log_stats(self):
        """Send a DO_LOG_STATS signal to the RPC Server"""

        with suppress(RPCClientClosedError):
            await self._send_one_way_rpc_request(
                request=RPCUtilityRequest.DO_LOG_STATS,
                error_message="RPCRequest DO_LOG_STATS failed.")

    @property
    def is_running(self) -> bool:
        return not self._errored

    @property
    def is_stopped(self) -> bool:
        return self._errored

    @property
    def errored(self) -> bool:
        return self._errored

    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[RequestOutput, None]:
        """Send an RPCGenerateRequest to the RPCServer and stream responses.

        Yields each RequestOutput until one reports ``finished``. If the
        server sends back an exception, a health check runs first (possibly
        setting ``errored``) and then the exception is re-raised. If the
        stream stops early and the client has not errored, an abort is sent
        for ``request_id``.
        """

        finished = False
        try:
            with self.to_proxy_socket() as socket:
                # Send RPCGenerateRequest to the RPCServer.
                await socket.send_multipart([
                    cloudpickle.dumps(
                        RPCGenerateRequest(
                            inputs=inputs,
                            sampling_params=sampling_params,
                            request_id=request_id,
                            lora_request=lora_request,
                            trace_headers=trace_headers,
                            prompt_adapter_request=prompt_adapter_request))
                ])

                # Stream back the results from the RPC Server.
                while not finished:
                    message = await socket.recv()
                    request_output = cloudpickle.loads(message)

                    if isinstance(request_output, Exception):
                        # On exception, check if the server is still healthy,
                        # possibly setting the `errored` property. NB: this
                        # runs before raising so that the flag is already set
                        # by the time the caller receives the exception.
                        if not self._errored:
                            try:
                                await self.check_health(socket=socket)
                            except Exception as e:
                                self._errored = True
                                logger.exception(repr(e))

                        raise request_output

                    finished = request_output.finished
                    yield request_output

        finally:
            # Request was canceled by the client.
            if not finished and not self._errored:
                await self.abort(request_id)

    async def check_health(self,
                           socket: Optional[zmq.asyncio.Socket] = None
                           ) -> None:
        """Raise if unhealthy"""

        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.IS_SERVER_HEALTHY,
            error_message="Got Unhealthy response from RPC Server",
            socket=socket)

    async def encode(self, *args,
                     **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Not supported by this backend; always raises."""
        raise NotImplementedError(
            "Embeddings not supported with multiprocessing backend")

    async def start_profile(self) -> None:
        """Start profiling the engine"""

        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.START_PROFILE,
            error_message="RPCRequest START_PROFILE failed.")

    async def stop_profile(self) -> None:
        """Stop profiling the engine"""

        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.STOP_PROFILE,
            error_message="RPCRequest STOP_PROFILE failed.")