server.py 9.17 KB
Newer Older
1
import asyncio
2
import pickle
3
import signal
4
from typing import Any, Coroutine, Union
5
6

import cloudpickle
7
import uvloop
8
9
10
import zmq
import zmq.asyncio
from typing_extensions import Never
11
12
from zmq import Frame  # type: ignore[attr-defined]
from zmq.asyncio import Socket
13
14

from vllm import AsyncEngineArgs, AsyncLLMEngine
15
16
17
18
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR,
                                         VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
19
20
21
22
23
24
                                         RPCGenerateRequest, RPCUtilityRequest)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext

# Module-level logger, named after this module.
logger = init_logger(__name__)

25
26
27
# Union of the engine configuration objects that can be returned to a client
# by AsyncEngineRPCServer.get_config.
CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
                    SchedulerConfig, LoRAConfig]

28
29
30
31

class AsyncEngineRPCServer:
    """RPC server that exposes an ``AsyncLLMEngine`` over a ZeroMQ socket.

    Requests arrive as two-frame multipart messages ``(identity, payload)``.
    The payload is unpickled, routed to one of the handler coroutines below,
    and every reply is sent back as ``(identity, pickled_result)``.
    """

    def __init__(self, async_engine_args: AsyncEngineArgs,
                 usage_context: UsageContext, rpc_path: str):
        """Build the engine and connect the RPC socket.

        Args:
            async_engine_args: Arguments forwarded to
                ``AsyncLLMEngine.from_engine_args``.
            usage_context: Usage context forwarded to the engine factory.
            rpc_path: ZeroMQ endpoint the DEALER socket connects to.
        """
        # Initialize engine first.
        self.engine = AsyncLLMEngine.from_engine_args(
            async_engine_args, usage_context=usage_context)

        # Initialize context.
        self.context = zmq.asyncio.Context()

        # Init socket.
        self.socket: Socket = self.context.socket(zmq.constants.DEALER)
        self.socket.set_hwm(VLLM_RPC_ZMQ_HWM)
        self.socket.connect(rpc_path)

    def cleanup(self):
        """Cleanup all resources."""
        self.socket.close()
        self.context.destroy()
        # Clear the engine reference so that it can be GC'ed.
        del self.engine

    async def get_config(self, identity, request):
        """Send the engine config selected by ``request`` to the client.

        Args:
            identity: Identity frame of the requester; echoed back as the
                first frame of the reply.
            request: One of the ``RPCUtilityRequest.GET_*_CONFIG`` members.

        Any exception (including an unknown ``request``) is pickled and sent
        to the client instead of the config.
        """
        try:
            config: CONFIG_TYPE
            if request == RPCUtilityRequest.GET_MODEL_CONFIG:
                config = await self.engine.get_model_config()
            elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
                config = await self.engine.get_decoding_config()
            elif request == RPCUtilityRequest.GET_LORA_CONFIG:
                config = await self.engine.get_lora_config()
            elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
                config = await self.engine.get_scheduler_config()
            elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
                config = await self.engine.get_parallel_config()
            else:
                # Exceptions do not perform %-style lazy formatting the way
                # logging calls do, so build the message eagerly.
                raise ValueError(f"Unknown Config Request: {request}")

            await self.socket.send_multipart((identity, pickle.dumps(config)),
                                             copy=False)

        except Exception as e:
            await self.socket.send_multipart((identity, pickle.dumps(e)),
                                             copy=False)

    async def is_tracing_enabled(self, identity):
        """Send the is_tracing_enabled flag"""
        tracing_flag = await self.engine.is_tracing_enabled()

        await self.socket.send_multipart(
            (identity, pickle.dumps(tracing_flag)))

    async def do_log_stats(self, identity):
        """Log stats and confirm success."""
        await self.engine.do_log_stats()

        await self.socket.send_multipart(
            (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))

    async def is_server_ready(self, identity):
        """Notify the client that we are ready."""
        await self.socket.send_multipart(
            (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))

    async def abort(self, identity, request: RPCAbortRequest):
        """Abort request and notify the client of success.

        If the engine raises while aborting, the exception is sent to the
        client in place of the success string.
        """
        try:
            # Abort the request in the llm engine.
            await self.engine.abort(request.request_id)
            result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
        except Exception as e:
            result = e
        await self.socket.send_multipart((identity, pickle.dumps(result)))

    async def generate(self, identity, generate_request: RPCGenerateRequest):
        """Run a generation request and stream its outputs to the client.

        Each item yielded by the engine's generator is pickled and sent as
        its own reply. If an exception is raised at any point, it is pickled
        and sent to the client instead.
        """
        try:
            results_generator = self.engine.generate(
                generate_request.inputs,
                sampling_params=generate_request.sampling_params,
                request_id=generate_request.request_id,
                lora_request=generate_request.lora_request,
                trace_headers=generate_request.trace_headers,
                prompt_adapter_request=generate_request.prompt_adapter_request)

            async for request_output in results_generator:
                await self.socket.send_multipart(
                    (identity, pickle.dumps(request_output)), copy=False)

        except Exception as e:
            await self.socket.send_multipart((identity, pickle.dumps(e)),
                                             copy=False)

    async def check_health(self, identity):
        """Run the engine health check and report the outcome.

        Sends the success string when ``engine.check_health()`` returns, or
        the pickled exception when it raises.
        """
        try:
            await self.engine.check_health()
            await self.socket.send_multipart(
                (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))

        except Exception as e:
            await self.socket.send_multipart((identity, pickle.dumps(e)),
                                             copy=False)

    async def start_profile(self, identity):
        """Start the engine profiler and confirm success to the client."""
        logger.info("Starting profiler...")
        await self.engine.start_profile()
        logger.info("Profiler started.")

        await self.socket.send_multipart((
            identity,
            pickle.dumps(VLLM_RPC_SUCCESS_STR),
        ))

    async def stop_profile(self, identity):
        """Stop the engine profiler and confirm success to the client."""
        logger.info("Stopping profiler...")
        await self.engine.stop_profile()
        logger.info("Profiler stopped.")

        await self.socket.send_multipart((
            identity,
            pickle.dumps(VLLM_RPC_SUCCESS_STR),
        ))

    def _make_handler_coro(self, identity,
                           message: Frame) -> Coroutine[Any, Any, Never]:
        """Route the zmq message to the handler coroutine.

        Args:
            identity: Identity frame of the requester.
            message: Payload frame holding the pickled request object.

        Returns:
            The (not yet awaited) coroutine of the matching handler.

        Raises:
            ValueError: If the unpickled request is not a known request type.
        """
        # NOTE(review): this unpickles bytes received from the socket;
        # unpickling can execute arbitrary code, so the RPC endpoint must
        # only be reachable by trusted peers.
        request = cloudpickle.loads(message.buffer)

        if isinstance(request, RPCGenerateRequest):
            return self.generate(identity, request)

        elif isinstance(request, RPCAbortRequest):
            return self.abort(identity, request)

        elif isinstance(request, RPCUtilityRequest):
            if request in [
                    RPCUtilityRequest.GET_MODEL_CONFIG,
                    RPCUtilityRequest.GET_PARALLEL_CONFIG,
                    RPCUtilityRequest.GET_DECODING_CONFIG,
                    RPCUtilityRequest.GET_SCHEDULER_CONFIG,
                    RPCUtilityRequest.GET_LORA_CONFIG
            ]:
                return self.get_config(identity, request)
            elif request == RPCUtilityRequest.DO_LOG_STATS:
                return self.do_log_stats(identity)
            elif request == RPCUtilityRequest.IS_SERVER_READY:
                return self.is_server_ready(identity)
            elif request == RPCUtilityRequest.IS_SERVER_HEALTHY:
                return self.check_health(identity)
            elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
                return self.is_tracing_enabled(identity)
            elif request == RPCUtilityRequest.START_PROFILE:
                return self.start_profile(identity)
            elif request == RPCUtilityRequest.STOP_PROFILE:
                return self.stop_profile(identity)
            else:
                raise ValueError(f"Unknown RPCUtilityRequest type: {request}")

        else:
            raise ValueError(f"Unknown RPCRequest type: {request}")

    async def run_server_loop(self):
        """Inner RPC Server Loop

        Receives requests forever and schedules one task per request, so a
        long-running handler does not block receipt of the next message.
        """

        running_tasks = set()
        while True:
            # Wait for a request.
            identity, message = await self.socket.recv_multipart(copy=False)

            # Process the request async.
            # NOTE(review): _make_handler_coro is a plain function, so a
            # ValueError for an unknown request type is raised here, inside
            # the loop, rather than inside the spawned task.
            task = asyncio.create_task(
                self._make_handler_coro(identity, message))

            # We need to keep around a strong reference to the task,
            # to avoid the task disappearing mid-execution as running tasks
            # can be GC'ed. Below is a common "fire-and-forget" tasks
            # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
            running_tasks.add(task)
            task.add_done_callback(running_tasks.discard)


async def run_server(server: AsyncEngineRPCServer):
    """Drive the server loop until it finishes or is interrupted.

    The loop runs as a task on the current event loop. SIGINT and SIGTERM
    cancel that task; the server's resources are cleaned up on every exit
    path.
    """
    event_loop = asyncio.get_running_loop()

    # Schedule the server loop on the running event loop.
    serve_task = event_loop.create_task(server.run_server_loop())

    def _cancel_server() -> None:
        # An interrupt / terminate request stops the server loop.
        serve_task.cancel()

    event_loop.add_signal_handler(signal.SIGINT, _cancel_server)
    event_loop.add_signal_handler(signal.SIGTERM, _cancel_server)

    try:
        await serve_task
    except asyncio.CancelledError:
        logger.info("vLLM ZMQ RPC Server was interrupted.")
    finally:
        # Release the socket, context and engine regardless of outcome.
        server.cleanup()


def run_rpc_server(async_engine_args: AsyncEngineArgs,
                   usage_context: UsageContext, rpc_path: str):
    """Construct the RPC server and run it under uvloop.

    A SIGTERM handler that raises ``KeyboardInterrupt`` is installed before
    the server is constructed, so a terminate signal received during
    initialization interrupts it.
    """

    def _raise_on_sigterm(*_) -> None:
        # Turn SIGTERM into an interrupt while the server is initializing.
        raise KeyboardInterrupt("AsyncEngineRPCServer terminated")

    signal.signal(signal.SIGTERM, _raise_on_sigterm)

    rpc_server = AsyncEngineRPCServer(async_engine_args, usage_context,
                                      rpc_path)
    serve_coro = run_server(rpc_server)
    uvloop.run(serve_coro)