"benchmarks/benchmark_topk_topp.py" did not exist on "dc5fa77a4eb6680339cb77abe713fb22d7795560"
server.py 9 KB
Newer Older
1
import asyncio
2
import pickle
3
import signal
4
from typing import Any, Coroutine, Union
5
6

import cloudpickle
7
import uvloop
8
9
10
import zmq
import zmq.asyncio
from typing_extensions import Never
11
12
from zmq import Frame  # type: ignore[attr-defined]
from zmq.asyncio import Socket
13
14

from vllm import AsyncEngineArgs, AsyncLLMEngine
15
16
17
18
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR,
                                         VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
19
20
21
22
23
24
                                         RPCGenerateRequest, RPCUtilityRequest)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext

# Module-level logger for this RPC server.
logger = init_logger(__name__)

# Union of every config type that a GET_*_CONFIG utility request can return.
CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
                    SchedulerConfig, LoRAConfig]


class AsyncEngineRPCServer:
    """Expose an ``AsyncLLMEngine`` to an RPC client over a ZMQ DEALER socket.

    Every message received is an ``(identity, payload)`` pair. The payload is
    unpickled into an RPC request and routed to a handler coroutine. The
    handler sends its reply back (several replies, for ``generate``) with the
    same ``identity`` as the first frame. When a handler fails, the exception
    object itself is pickled and sent as the reply, so the client is not left
    waiting for a message that never arrives.
    """

    def __init__(self, async_engine_args: AsyncEngineArgs,
                 usage_context: UsageContext, rpc_path: str):
        """Create the engine, then a DEALER socket connected to ``rpc_path``.

        Args:
            async_engine_args: Arguments used to build the AsyncLLMEngine.
            usage_context: Usage context passed to the engine factory.
            rpc_path: ZMQ endpoint the socket connects to.
        """
        # Initialize engine first.
        self.engine = AsyncLLMEngine.from_engine_args(
            async_engine_args, usage_context=usage_context)

        # Initialize context.
        self.context = zmq.asyncio.Context()

        # Init socket.
        self.socket: Socket = self.context.socket(zmq.constants.DEALER)
        self.socket.set_hwm(VLLM_RPC_ZMQ_HWM)
        self.socket.connect(rpc_path)

    def cleanup(self):
        """Cleanup all resources.

        Closes the socket, destroys the ZMQ context and shuts down the
        engine's background loop. Do not call it twice: the engine reference
        is deleted at the end, so a second call fails on ``self.engine``.
        """
        self.socket.close()
        self.context.destroy()
        self.engine.shutdown_background_loop()
        # Clear the engine reference so that it can be GC'ed.
        del self.engine

    async def _send_response(self, identity, response) -> None:
        """Pickle ``response`` and send it with ``identity`` as first frame."""
        await self.socket.send_multipart((identity, pickle.dumps(response)),
                                         copy=False)

    async def get_config(self, identity, request):
        """Send the engine config selected by ``request``.

        Args:
            identity: Identity frame received with the request; echoed back
                as the first frame of the reply.
            request: One of the ``RPCUtilityRequest.GET_*_CONFIG`` members.

        Replies with the config object, or with the exception raised while
        fetching it (including ``ValueError`` for an unknown request).
        """
        try:
            config: CONFIG_TYPE
            if request == RPCUtilityRequest.GET_MODEL_CONFIG:
                config = await self.engine.get_model_config()
            elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
                config = await self.engine.get_decoding_config()
            elif request == RPCUtilityRequest.GET_LORA_CONFIG:
                config = await self.engine.get_lora_config()
            elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
                config = await self.engine.get_scheduler_config()
            elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
                config = await self.engine.get_parallel_config()
            else:
                # Exceptions do not apply %-style formatting to extra args,
                # so build the message eagerly.
                raise ValueError(f"Unknown Config Request: {request}")

            await self._send_response(identity, config)

        except Exception as e:
            await self._send_response(identity, e)

    async def is_tracing_enabled(self, identity):
        """Send the is_tracing_enabled flag (or the error raised reading it)."""
        try:
            tracing_flag = await self.engine.is_tracing_enabled()
            await self._send_response(identity, tracing_flag)
        except Exception as e:
            await self._send_response(identity, e)

    async def do_log_stats(self, identity):
        """Log stats and confirm success (or send the error raised)."""
        try:
            await self.engine.do_log_stats()
            await self._send_response(identity, VLLM_RPC_SUCCESS_STR)
        except Exception as e:
            await self._send_response(identity, e)

    async def is_server_ready(self, identity):
        """Notify the client that we are ready."""
        await self._send_response(identity, VLLM_RPC_SUCCESS_STR)

    async def abort(self, identity, request: RPCAbortRequest):
        """Abort request and notify the client of success (or the error)."""
        try:
            # Abort the request in the llm engine.
            await self.engine.abort(request.request_id)
            result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
        except Exception as e:
            result = e
        await self._send_response(identity, result)

    async def generate(self, identity, generate_request: RPCGenerateRequest):
        """Stream every output of ``engine.generate`` back to the client.

        Each yielded output is sent as its own message. If generation raises,
        the exception is sent as a final message after any outputs that were
        already streamed.
        """
        try:
            results_generator = self.engine.generate(
                generate_request.inputs,
                sampling_params=generate_request.sampling_params,
                request_id=generate_request.request_id,
                lora_request=generate_request.lora_request,
                trace_headers=generate_request.trace_headers,
                prompt_adapter_request=generate_request.prompt_adapter_request)

            async for request_output in results_generator:
                await self._send_response(identity, request_output)

        except Exception as e:
            await self._send_response(identity, e)

    async def check_health(self, identity):
        """Send VLLM_RPC_SUCCESS_STR if the engine check passes, else the error."""
        try:
            await self.engine.check_health()
            await self._send_response(identity, VLLM_RPC_SUCCESS_STR)
        except Exception as e:
            await self._send_response(identity, e)

    async def start_profile(self, identity):
        """Start the engine profiler and confirm success (or send the error)."""
        try:
            logger.info("Starting profiler...")
            await self.engine.start_profile()
            logger.info("Profiler started.")
            await self._send_response(identity, VLLM_RPC_SUCCESS_STR)
        except Exception as e:
            await self._send_response(identity, e)

    async def stop_profile(self, identity):
        """Stop the engine profiler and confirm success (or send the error)."""
        try:
            logger.info("Stopping profiler...")
            await self.engine.stop_profile()
            logger.info("Profiler stopped.")
            await self._send_response(identity, VLLM_RPC_SUCCESS_STR)
        except Exception as e:
            await self._send_response(identity, e)

    def _make_handler_coro(self, identity,
                           message: Frame) -> Coroutine[Any, Any, None]:
        """Route the zmq message to the handler coroutine.

        Returns the handler coroutine, not yet scheduled. Raises if the
        payload cannot be unpickled or is not a known request type. The
        caller is responsible for reporting that failure to the client.
        """
        # NOTE(review): cloudpickle.loads can execute arbitrary code embedded
        # in the payload. This is only safe if every peer on rpc_path is
        # trusted; confirm the endpoint is never exposed externally.
        request = cloudpickle.loads(message.buffer)

        if isinstance(request, RPCGenerateRequest):
            return self.generate(identity, request)

        elif isinstance(request, RPCAbortRequest):
            return self.abort(identity, request)

        elif isinstance(request, RPCUtilityRequest):
            if request in (
                    RPCUtilityRequest.GET_MODEL_CONFIG,
                    RPCUtilityRequest.GET_PARALLEL_CONFIG,
                    RPCUtilityRequest.GET_DECODING_CONFIG,
                    RPCUtilityRequest.GET_SCHEDULER_CONFIG,
                    RPCUtilityRequest.GET_LORA_CONFIG,
            ):
                return self.get_config(identity, request)
            elif request == RPCUtilityRequest.DO_LOG_STATS:
                return self.do_log_stats(identity)
            elif request == RPCUtilityRequest.IS_SERVER_READY:
                return self.is_server_ready(identity)
            elif request == RPCUtilityRequest.IS_SERVER_HEALTHY:
                return self.check_health(identity)
            elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
                return self.is_tracing_enabled(identity)
            elif request == RPCUtilityRequest.START_PROFILE:
                return self.start_profile(identity)
            elif request == RPCUtilityRequest.STOP_PROFILE:
                return self.stop_profile(identity)
            else:
                raise ValueError(f"Unknown RPCUtilityRequest type: {request}")

        else:
            raise ValueError(f"Unknown RPCRequest type: {request}")

    async def run_server_loop(self) -> Never:
        """Inner RPC Server Loop.

        Receives requests forever and runs each handler as its own task, so
        a long-running request (e.g. a streaming ``generate``) does not block
        the loop. It only exits by raising, e.g. when the task is cancelled.
        """
        running_tasks = set()
        while True:
            # Wait for a request.
            identity, message = await self.socket.recv_multipart(copy=False)

            try:
                handler = self._make_handler_coro(identity, message)
            except Exception as e:
                # A request that cannot be decoded or routed must not escape
                # here: that would end this loop and take every in-flight
                # request down with it. Report the failure to its sender.
                logger.exception("Failed to route RPC request.")
                await self._send_response(identity, e)
                continue

            # Process the request async.
            task = asyncio.create_task(handler)

            # We need to keep around a strong reference to the task,
            # to avoid the task disappearing mid-execution as running tasks
            # can be GC'ed. Below is a common "fire-and-forget" tasks
            # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
            running_tasks.add(task)
            task.add_done_callback(running_tasks.discard)


async def run_server(server: AsyncEngineRPCServer):
    """Run ``server``'s request loop until it is cancelled or fails.

    SIGINT and SIGTERM cancel the loop task. ``server.cleanup()`` runs
    however the loop ends.
    """
    event_loop = asyncio.get_running_loop()
    serving = event_loop.create_task(server.run_server_loop())

    def _cancel_serving() -> None:
        # Kill the server on interrupt / terminate.
        serving.cancel()

    for signum in (signal.SIGINT, signal.SIGTERM):
        event_loop.add_signal_handler(signum, _cancel_serving)

    try:
        await serving
    except asyncio.CancelledError:
        logger.info("vLLM ZMQ RPC Server was interrupted.")
    finally:
        # Release the socket, ZMQ context and engine regardless of outcome.
        server.cleanup()


def run_rpc_server(async_engine_args: AsyncEngineArgs,
                   usage_context: UsageContext, rpc_path: str):
    """Build the RPC server and block while running it on a uvloop loop.

    Args:
        async_engine_args: Arguments for building the AsyncLLMEngine.
        usage_context: Usage context forwarded to the engine.
        rpc_path: ZMQ endpoint the server's socket connects to.
    """
    rpc_server = AsyncEngineRPCServer(async_engine_args=async_engine_args,
                                      usage_context=usage_context,
                                      rpc_path=rpc_path)
    uvloop.run(run_server(rpc_server))