# client.py
1
from contextlib import contextmanager
2
from typing import Any, AsyncGenerator, Mapping, Optional
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20

import cloudpickle
import zmq
import zmq.asyncio

from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
                                         VLLM_RPC_HEALTHY_STR,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                         RPCGenerateRequest, RPCUtilityRequest)
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs

21
22
23
# Time (in milliseconds) to wait before checking if the server process is
# alive.
SERVER_START_TIMEOUT_MS = 1000

24
25
26

class AsyncEngineRPCClient:
    """Asynchronous ZeroMQ client for talking to the engine's RPC server.

    Each request opens a short-lived DEALER socket connected to ``rpc_path``,
    sends a cloudpickle-encoded request and reads cloudpickle-encoded replies
    from that same socket.

    NOTE(review): replies are deserialised with ``cloudpickle.loads``, which
    can execute arbitrary code, so ``rpc_path`` must only ever point at a
    trusted server.
    """

    def __init__(self, rpc_path: str):
        """Create the ZeroMQ context and remember the server endpoint.

        Args:
            rpc_path: ZeroMQ endpoint that every request socket connects to.
        """
        self.context = zmq.asyncio.Context()
        self.rpc_path = rpc_path
        # Initialised here (and reset in `setup`) so the synchronous
        # `is_running` / `is_stopped` / `errored` properties never raise
        # AttributeError when read before `setup` has completed.
        self._errored = False

    async def setup(self):
        """Setup the client before it starts sending server requests.

        Waits for the server to report ready, then fetches the configs and
        builds the tokenizer group from them.

        Raises:
            TimeoutError: If the server does not answer the readiness
                request within ``SERVER_START_TIMEOUT_MS`` milliseconds.
            ValueError: If any reply is not of the expected kind.
        """

        # Wait until server is ready.
        await self.wait_for_server()
        self._errored = False

        # Get the configs.
        self.model_config = await self._get_model_config_rpc()
        self.decoding_config = await self._get_decoding_config_rpc()
        self.tracing_flag = await self._is_tracing_enabled_rpc()

        # Create the tokenizer group.
        # TODO: refactor OAI server to avoid needing this info.
        self.tokenizer = init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=(await self._get_scheduler_config_rpc()),
            parallel_config=(await self._get_parallel_config_rpc()),
            enable_lora=bool(await self._get_lora_config_rpc()),
        )

    def close(self):
        """Destroy the ZeroMQ Context."""
        self.context.destroy()

    @contextmanager
    def socket(self):
        """Yield a DEALER socket connected to ``rpc_path``.

        The socket is always closed when the ``with`` block exits, whether
        normally or through an exception.
        """
        # Connect to RPC socket for Request-Reply pattern,
        # Note that we use DEALER to enable asynchronous communication
        # to enable streaming.
        socket = self.context.socket(zmq.constants.DEALER)
        try:
            socket.connect(self.rpc_path)
            yield socket
        finally:
            # linger == 0 means discard unsent messages
            # when the socket is closed. This is necessary
            # because otherwise self.context.destroy() will
            # wait for 30 seconds until unsent messages are
            # received, which is impossible if the server
            # crashed. In the absence of a server crash we
            # always expect a response before closing the
            # socket anyway.
            # Reference: http://api.zeromq.org/4-2:zmq-setsockopt#toc24
            socket.close(linger=0)

    async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
                                         expected_type: Any,
                                         error_message: str) -> Any:
        """Send an RPC request that is expecting data back.

        Args:
            request: The utility request to send.
            expected_type: Type the reply must be an instance of. When this
                is ``LoRAConfig`` a ``None`` reply is accepted as well.
            error_message: Message of the ``ValueError`` raised on mismatch.

        Returns:
            The unpickled reply.

        Raises:
            ValueError: If the reply is not of the expected type. When the
                reply is itself an exception it is attached as the cause.
        """

        with self.socket() as socket:

            # Ping RPCServer with a request.
            await socket.send(cloudpickle.dumps(request))

            # Await the data from the Server.
            data = cloudpickle.loads(await socket.recv())

        # LoRAConfig can be None.
        lora_config_absent = expected_type == LoRAConfig and data is None

        if not isinstance(data, expected_type) and not lora_config_absent:
            if isinstance(data, Exception):
                # Keep the server-side failure visible in the traceback.
                raise ValueError(error_message) from data
            raise ValueError(error_message)

        return data

    async def _send_one_way_rpc_request(self,
                                        request: RPC_REQUEST_TYPE,
                                        error_message: str,
                                        timeout: Optional[int] = None):
        """Send one-way RPC request to trigger an action.

        Args:
            request: The request to send.
            error_message: Message of the ``ValueError`` raised when the
                server does not acknowledge with ``VLLM_RPC_SUCCESS_STR``.
            timeout: Optional time in milliseconds to wait for the reply.
                ``None`` waits indefinitely.

        Returns:
            The acknowledgement string received from the server.

        Raises:
            TimeoutError: If ``timeout`` is given and no reply arrives in
                time.
            ValueError: If the reply is not the success string.
        """
        with self.socket() as socket:
            # Ping RPC Server with request.
            await socket.send(cloudpickle.dumps(request))

            # Await acknowledgement from RPCServer.
            if timeout is not None and await socket.poll(timeout=timeout) == 0:
                raise TimeoutError(f"server didn't reply within {timeout} ms")

            response = cloudpickle.loads(await socket.recv())

        if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
            raise ValueError(error_message)

        return response

    async def get_tokenizer(self, lora_request: LoRARequest):
        """Return the tokenizer for ``lora_request`` from the tokenizer group."""
        return await self.tokenizer.get_lora_tokenizer_async(lora_request)

    async def get_decoding_config(self) -> DecodingConfig:
        """Return the DecodingConfig cached by `setup`."""
        return self.decoding_config

    async def get_model_config(self) -> ModelConfig:
        """Return the ModelConfig cached by `setup`."""
        return self.model_config

    async def is_tracing_enabled(self) -> bool:
        """Return the tracing flag cached by `setup`."""
        return self.tracing_flag

    async def wait_for_server(self):
        """Wait for the RPCServer to start up.

        Raises:
            TimeoutError: If there is no reply within
                ``SERVER_START_TIMEOUT_MS`` milliseconds.
            ValueError: If the reply is not the success string.
        """

        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.IS_SERVER_READY,
            error_message="Unable to start RPC Server.",
            timeout=SERVER_START_TIMEOUT_MS)

    async def _get_model_config_rpc(self) -> ModelConfig:
        """Get the ModelConfig object from the RPC Server"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_MODEL_CONFIG,
            expected_type=ModelConfig,
            error_message="Could not get ModelConfig from RPC Server")

    async def _get_decoding_config_rpc(self) -> DecodingConfig:
        """Get DecodingConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_DECODING_CONFIG,
            expected_type=DecodingConfig,
            error_message="Could not get DecodingConfig from RPC Server")

    async def _get_parallel_config_rpc(self) -> ParallelConfig:
        """Get ParallelConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_PARALLEL_CONFIG,
            expected_type=ParallelConfig,
            error_message="Could not get ParallelConfig from RPC Server")

    async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
        """Get SchedulerConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_SCHEDULER_CONFIG,
            expected_type=SchedulerConfig,
            error_message="Could not get SchedulerConfig from RPC Server")

    async def _get_lora_config_rpc(self) -> LoRAConfig:
        """Get LoRAConfig from the RPCServer (the reply may be None)"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_LORA_CONFIG,
            expected_type=LoRAConfig,
            error_message="Could not get LoRAConfig from RPC Server")

    async def _is_tracing_enabled_rpc(self) -> bool:
        """Get is_tracing_enabled flag from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.IS_TRACING_ENABLED,
            expected_type=bool,
            error_message="Could not get is_tracing_enabled flag from RPC "
            "Server")

    async def abort(self, request_id: str):
        """Send an ABORT_REQUEST signal to the RPC Server"""

        await self._send_one_way_rpc_request(
            request=RPCAbortRequest(request_id),
            error_message=f"RPCAbortRequest {request_id} failed")

    async def do_log_stats(self):
        """Send a DO_LOG_STATS signal to the RPC Server"""

        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.DO_LOG_STATS,
            error_message="RPCRequest DO_LOG_STATS failed.")

    @property
    def is_running(self) -> bool:
        """True unless a failed health check has marked the client errored."""
        return not self._errored

    @property
    def is_stopped(self) -> bool:
        """True once a failed health check has marked the client errored."""
        return self._errored

    @property
    def errored(self) -> bool:
        """True once a failed health check has marked the client errored."""
        return self._errored

    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[RequestOutput, None]:
        """Send an RPCGenerateRequest to the RPCServer and stream responses.

        Yields each reply until one reports ``finished``. If the server
        replies with an exception, that exception is raised after a health
        check has updated the errored flag. If the stream stops early for
        any other reason, the request is aborted on the server.
        """

        finished = False
        try:
            with self.socket() as socket:

                # Send RPCGenerateRequest to the RPCServer.
                await socket.send_multipart([
                    cloudpickle.dumps(
                        RPCGenerateRequest(
                            inputs=inputs,
                            sampling_params=sampling_params,
                            request_id=request_id,
                            lora_request=lora_request,
                            trace_headers=trace_headers,
                            prompt_adapter_request=prompt_adapter_request))
                ])

                # Stream back the results from the RPC Server.
                while not finished:
                    message = await socket.recv()
                    request_output = cloudpickle.loads(message)

                    if isinstance(request_output, Exception):
                        # On exception, check if the server is still healthy.
                        # Use this to set the sync `is_running` and `errored`
                        # properties.
                        try:
                            await self.check_health()
                        except Exception:
                            self._errored = True
                        # NB: do before raising here so that the flag is set
                        # by the time the caller receives this exception
                        raise request_output

                    finished = request_output.finished
                    yield request_output
        finally:
            # Abort on the server only while it is still considered healthy.
            # `abort` waits for an acknowledgement without a timeout, so
            # calling it after a failed health check could block forever and
            # keep the original exception from reaching the caller.
            if not finished and not self._errored:
                await self.abort(request_id)

    async def check_health(self) -> None:
        """Raise if unhealthy

        Raises:
            Exception: Whatever exception object the server replied with.
            ValueError: If the reply is not ``VLLM_RPC_HEALTHY_STR``.
        """

        with self.socket() as socket:

            # Ping RPCServer with CHECK_HEALTH request.
            await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH)
                              )

            # Await the reply from the server.
            # TODO: do we need an internal timeout here?
            # Or do we expect the external probe to timeout and let this chill?
            health_message = cloudpickle.loads(await socket.recv())

        if isinstance(health_message, Exception):
            raise health_message

        if health_message != VLLM_RPC_HEALTHY_STR:
            # The `f` prefix belongs outside the quotes so the offending
            # reply is actually interpolated into the message.
            raise ValueError("Expected healthy response from backend but got "
                             f"{health_message}")

    async def encode(self, *args,
                     **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Embeddings are not available through this client; always raises."""
        raise NotImplementedError(
            "Embeddings not supported with multiprocessing backend")