from contextlib import contextmanager
from typing import Any, AsyncGenerator, Mapping, Optional

import cloudpickle
import zmq
import zmq.asyncio

from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
                                         VLLM_RPC_HEALTHY_STR,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                         RPCGenerateRequest, RPCUtilityRequest)
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs


class AsyncEngineRPCClient:
    """Asyncio client that talks to the vLLM RPC server over ZeroMQ.

    Each request opens a short-lived DEALER socket to
    ``tcp://localhost:<port>``, sends a cloudpickled request object and
    unpickles the reply (or the stream of replies, for ``generate``).

    Call :meth:`setup` once before using the cached config accessors
    (:meth:`get_model_config`, :meth:`get_decoding_config`,
    :meth:`is_tracing_enabled`) or :meth:`get_tokenizer`; those read
    attributes that only :meth:`setup` assigns.

    NOTE: replies are deserialized with ``cloudpickle.loads``, which can
    execute arbitrary code. The peer on the other end of the socket must
    therefore be trusted.
    """

    def __init__(self, port: int):
        """Create the ZeroMQ context and record the server address.

        No socket is opened here; sockets are created per request by
        :meth:`socket`.

        Args:
            port: TCP port on localhost used to reach the RPC server.
        """
        self.context = zmq.asyncio.Context()
        self.path = f"tcp://localhost:{port}"

    async def setup(self):
        """Setup the client before it starts sending server requests."""

        # Wait until server is ready.
        await self.wait_for_server()

        # Fetch the configs once and cache them on the instance.
        self.model_config = await self._get_model_config_rpc()
        self.decoding_config = await self._get_decoding_config_rpc()
        self.tracing_flag = await self._is_tracing_enabled_rpc()

        # Create the tokenizer group.
        # TODO: refactor OAI server to avoid needing this info.
        self.tokenizer = init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=(await self._get_scheduler_config_rpc()),
            parallel_config=(await self._get_parallel_config_rpc()),
            # The LoRA config may be None, so only its truthiness is used.
            enable_lora=bool(await self._get_lora_config_rpc()),
        )

    def close(self):
        """Destroy the ZeroMQ Context."""
        self.context.destroy()

    @contextmanager
    def socket(self):
        """Yield a DEALER socket connected to the RPC server.

        The socket is closed when the ``with`` block exits, whether or not
        an exception was raised, so client sockets never outlive a request.
        """
        # Connect to RPC socket for Request-Reply pattern,
        # Note that we use DEALER to enable asynchronous communication
        # to enable streaming.
        socket = self.context.socket(zmq.constants.DEALER)
        try:
            socket.connect(self.path)
            yield socket
        finally:
            socket.close()

    async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
                                         expected_type: Any,
                                         error_message: str) -> Any:
        """Send an RPC request that is expecting data back.

        Args:
            request: Utility request to send to the server.
            expected_type: Type (as accepted by ``isinstance``) that the
                reply must be an instance of.
            error_message: Message of the ``ValueError`` raised when the
                reply has the wrong type.

        Returns:
            The unpickled reply. It may be ``None`` only when
            ``expected_type`` is ``LoRAConfig``.

        Raises:
            ValueError: If the reply is not an instance of
                ``expected_type`` (and is not the permitted ``None``).
        """
        with self.socket() as socket:

            # Ping RPCServer with a request.
            await socket.send(cloudpickle.dumps(request))

            # Await the data from the Server.
            data = cloudpickle.loads(await socket.recv())

        # LoRAConfig is the one config for which None is a valid reply.
        none_is_allowed = expected_type == LoRAConfig and data is None
        if not (isinstance(data, expected_type) or none_is_allowed):
            raise ValueError(error_message)

        return data

    async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
                                        error_message: str):
        """Send one-way RPC request to trigger an action.

        Args:
            request: Request object to send to the server.
            error_message: Message of the ``ValueError`` raised when the
                server does not acknowledge with the success string.

        Returns:
            The acknowledgement string received from the server.

        Raises:
            ValueError: If the reply is anything other than
                ``VLLM_RPC_SUCCESS_STR``.
        """
        with self.socket() as socket:
            # Ping RPC Server with request.
            await socket.send(cloudpickle.dumps(request))

            # Await acknowledgement from RPCServer.
            response = cloudpickle.loads(await socket.recv())

        if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
            raise ValueError(error_message)

        return response

    async def get_tokenizer(self, lora_request: LoRARequest):
        """Return the tokenizer for ``lora_request`` from the tokenizer group.

        Requires :meth:`setup` to have run (it creates ``self.tokenizer``).
        """
        return await self.tokenizer.get_lora_tokenizer_async(lora_request)

    async def get_decoding_config(self) -> DecodingConfig:
        """Return the DecodingConfig cached by :meth:`setup`."""
        return self.decoding_config

    async def get_model_config(self) -> ModelConfig:
        """Return the ModelConfig cached by :meth:`setup`."""
        return self.model_config

    async def is_tracing_enabled(self) -> bool:
        """Return the tracing flag cached by :meth:`setup`."""
        return self.tracing_flag

    async def wait_for_server(self):
        """Wait for the RPCServer to start up."""

        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.IS_SERVER_READY,
            error_message="Unable to start RPC Server.")

    async def _get_model_config_rpc(self) -> ModelConfig:
        """Get the ModelConfig object from the RPC Server"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_MODEL_CONFIG,
            expected_type=ModelConfig,
            error_message="Could not get ModelConfig from RPC Server")

    async def _get_decoding_config_rpc(self) -> DecodingConfig:
        """Get DecodingConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_DECODING_CONFIG,
            expected_type=DecodingConfig,
            error_message="Could not get DecodingConfig from RPC Server")

    async def _get_parallel_config_rpc(self) -> ParallelConfig:
        """Get ParallelConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_PARALLEL_CONFIG,
            expected_type=ParallelConfig,
            error_message="Could not get ParallelConfig from RPC Server")

    async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
        """Get SchedulerConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_SCHEDULER_CONFIG,
            expected_type=SchedulerConfig,
            error_message="Could not get SchedulerConfig from RPC Server")

    async def _get_lora_config_rpc(self) -> Optional[LoRAConfig]:
        """Get LoRAConfig from the RPCServer (a ``None`` reply is accepted)"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_LORA_CONFIG,
            expected_type=LoRAConfig,
            error_message="Could not get LoRAConfig from RPC Server")

    async def _is_tracing_enabled_rpc(self) -> bool:
        """Get is_tracing_enabled flag from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.IS_TRACING_ENABLED,
            expected_type=bool,
            error_message="Could not get is_tracing_enabled flag from RPC "
            "Server")

    async def abort(self, request_id: str):
        """Send an ABORT_REQUEST signal to the RPC Server"""

        await self._send_one_way_rpc_request(
            request=RPCAbortRequest(request_id),
            error_message=f"RPCAbortRequest {request_id} failed")

    async def do_log_stats(self):
        """Send a DO_LOG_STATS signal to the RPC Server"""

        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.DO_LOG_STATS,
            error_message="RPCRequest DO_LOG_STATS failed.")

    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[RequestOutput, None]:
        """Send an RPCGenerateRequest to the RPCServer and stream responses.

        Args:
            inputs: Prompt inputs forwarded to the server.
            sampling_params: Sampling parameters forwarded to the server.
            request_id: Identifier of the request; also used to abort it.
            lora_request: Optional LoRA request forwarded to the server.
            trace_headers: Optional trace headers forwarded to the server.
            prompt_adapter_request: Optional prompt adapter request
                forwarded to the server.

        Yields:
            Each unpickled output received from the server, up to and
            including the first one whose ``finished`` attribute is truthy.

        Raises:
            Exception: Any exception instance received from the server is
                re-raised as-is.
        """
        finished = False
        try:
            with self.socket() as socket:

                # Send RPCGenerateRequest to the RPCServer.
                await socket.send_multipart([
                    cloudpickle.dumps(
                        RPCGenerateRequest(
                            inputs=inputs,
                            sampling_params=sampling_params,
                            request_id=request_id,
                            lora_request=lora_request,
                            trace_headers=trace_headers,
                            prompt_adapter_request=prompt_adapter_request))
                ])

                # Stream back the results from the RPC Server.
                while not finished:
                    message = await socket.recv()
                    request_output = cloudpickle.loads(message)

                    if isinstance(request_output, Exception):
                        raise request_output

                    finished = request_output.finished
                    yield request_output
        finally:
            # If we leave before seeing a finished output (an error, or the
            # generator being closed early), ask the server to abort.
            if not finished:
                await self.abort(request_id)

    async def check_health(self) -> None:
        """Raise if unhealthy

        Raises:
            Exception: Any exception instance received from the server is
                re-raised as-is.
            ValueError: If the reply is not ``VLLM_RPC_HEALTHY_STR``.
        """

        with self.socket() as socket:

            # Ping RPCServer with CHECK_HEALTH request.
            await socket.send(
                cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH))

            # Await the reply from the server.
            # TODO: do we need an internal timeout here?
            # Or do we expect the external probe to timeout and let this chill?
            health_message = cloudpickle.loads(await socket.recv())

        if isinstance(health_message, Exception):
            raise health_message

        if health_message != VLLM_RPC_HEALTHY_STR:
            # The `f` prefix must sit outside the quotes for the value to be
            # interpolated into the message.
            raise ValueError("Expected healthy response from backend but got "
                             f"{health_message}")

    async def encode(self, *args,
                     **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Not supported by this client; always raises NotImplementedError."""
        raise NotImplementedError(
            "Embeddings not supported with multiprocessing backend")