"docs/pages/kubernetes/installation-guide.md" did not exist on "ad4821c56988146b8067bfc6b3653a099c0676a6"
server.py 8.72 KB
Newer Older
1
2
import asyncio
import signal
3
from typing import Any, Coroutine, Union
4
5

import cloudpickle
6
import uvloop
7
8
9
10
11
import zmq
import zmq.asyncio
from typing_extensions import Never

from vllm import AsyncEngineArgs, AsyncLLMEngine
12
13
14
15
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR,
                                         VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
16
17
18
19
20
21
                                         RPCGenerateRequest, RPCUtilityRequest)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext

# Module-level logger, created via vLLM's `init_logger` for this module.
logger = init_logger(__name__)

22
23
24
# Union of the config classes that `AsyncEngineRPCServer.get_config` can
# fetch from the engine; used there to annotate the local `config` variable.
CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
                    SchedulerConfig, LoRAConfig]

25
26
27
28

class AsyncEngineRPCServer:
    """RPC server that exposes an ``AsyncLLMEngine`` over a ZMQ socket.

    Requests arrive as two-frame multipart messages ``[identity, payload]``
    where ``payload`` is a cloudpickled request object. Every reply is sent
    as ``[identity, cloudpickle.dumps(result)]``; when a handler fails, the
    raised exception object itself is sent as the result.

    NOTE(review): ``identity`` is presumably the routing id of the requesting
    client; this class only echoes it back as the first reply frame.
    """

    def __init__(self, async_engine_args: AsyncEngineArgs,
                 usage_context: UsageContext, rpc_path: str):
        """Create the engine and connect the RPC socket.

        Args:
            async_engine_args: Arguments used to build the ``AsyncLLMEngine``.
            usage_context: Forwarded to ``AsyncLLMEngine.from_engine_args``.
            rpc_path: ZMQ endpoint the DEALER socket connects to.
        """
        # Initialize engine first.
        self.engine = AsyncLLMEngine.from_engine_args(
            async_engine_args, usage_context=usage_context)

        # Initialize context.
        self.context = zmq.asyncio.Context()

        # Init socket.
        self.socket = self.context.socket(zmq.constants.DEALER)
        self.socket.set_hwm(VLLM_RPC_ZMQ_HWM)
        self.socket.connect(rpc_path)

    def cleanup(self):
        """Cleanup all resources."""
        self.socket.close()
        self.context.destroy()
        self.engine.shutdown_background_loop()
        # Clear the engine reference so that it can be GC'ed.
        del self.engine

    async def _send(self, identity, payload):
        """Send ``payload`` (cloudpickled) to ``identity`` as a 2-frame reply."""
        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(payload)])

    async def get_config(self, identity, request):
        """Send the engine config selected by ``request``.

        Args:
            identity: First frame of the request, echoed back in the reply.
            request: One of the ``RPCUtilityRequest.GET_*_CONFIG`` values.

        Any exception (including an unknown ``request``) is sent back to the
        requester instead of being raised.
        """
        try:
            config: CONFIG_TYPE
            if request == RPCUtilityRequest.GET_MODEL_CONFIG:
                config = await self.engine.get_model_config()
            elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
                config = await self.engine.get_decoding_config()
            elif request == RPCUtilityRequest.GET_LORA_CONFIG:
                config = await self.engine.get_lora_config()
            elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
                config = await self.engine.get_scheduler_config()
            elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
                config = await self.engine.get_parallel_config()
            else:
                # Exceptions do not interpolate logging-style "%s" arguments,
                # so the message has to be formatted explicitly.
                raise ValueError(f"Unknown Config Request: {request}")

            await self._send(identity, config)

        except Exception as e:
            await self._send(identity, e)

    async def is_tracing_enabled(self, identity):
        """Send the is_tracing_enabled flag, or the exception on failure."""
        try:
            tracing_flag = await self.engine.is_tracing_enabled()
            await self._send(identity, tracing_flag)
        except Exception as e:
            # Reply with the error like the other handlers do, so the
            # requester is not left without any response.
            await self._send(identity, e)

    async def do_log_stats(self, identity):
        """Log stats and confirm success, or send the exception on failure."""
        try:
            await self.engine.do_log_stats()
            await self._send(identity, VLLM_RPC_SUCCESS_STR)
        except Exception as e:
            await self._send(identity, e)

    async def is_server_ready(self, identity):
        """Notify the client that we are ready."""
        await self._send(identity, VLLM_RPC_SUCCESS_STR)

    async def abort(self, identity, request: RPCAbortRequest):
        """Abort request and notify the client of success."""
        try:
            # Abort the request in the llm engine.
            await self.engine.abort(request.request_id)
            result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
        except Exception as e:
            result = e
        await self._send(identity, result)

    async def generate(self, identity, generate_request: RPCGenerateRequest):
        """Run a generation request and stream each output to the client.

        Every item yielded by ``self.engine.generate`` is sent as its own
        reply. If an exception is raised at any point, the exception object
        is sent instead and streaming stops.
        """
        try:
            results_generator = self.engine.generate(
                generate_request.inputs,
                sampling_params=generate_request.sampling_params,
                request_id=generate_request.request_id,
                lora_request=generate_request.lora_request,
                trace_headers=generate_request.trace_headers,
                prompt_adapter_request=generate_request.prompt_adapter_request)

            async for request_output in results_generator:
                await self._send(identity, request_output)

        except Exception as e:
            await self._send(identity, e)

    async def check_health(self, identity):
        """Run the engine health check and report success or the exception."""
        try:
            await self.engine.check_health()
            await self._send(identity, VLLM_RPC_SUCCESS_STR)

        except Exception as e:
            await self._send(identity, e)

    async def start_profile(self, identity):
        """Start the engine profiler and confirm, or send the exception."""
        try:
            logger.info("Starting profiler...")
            await self.engine.start_profile()
            logger.info("Profiler started.")

            await self._send(identity, VLLM_RPC_SUCCESS_STR)
        except Exception as e:
            await self._send(identity, e)

    async def stop_profile(self, identity):
        """Stop the engine profiler and confirm, or send the exception."""
        try:
            logger.info("Stopping profiler...")
            await self.engine.stop_profile()
            logger.info("Profiler stopped.")

            await self._send(identity, VLLM_RPC_SUCCESS_STR)
        except Exception as e:
            await self._send(identity, e)

    def _make_handler_coro(self, identity,
                           message) -> Coroutine[Any, Any, Never]:
        """Route the zmq message to the handler coroutine.

        Args:
            identity: First frame of the request, passed on to the handler.
            message: Cloudpickled request object.

        Returns:
            The (not yet awaited) coroutine of the matching handler.

        Raises:
            ValueError: If the decoded request is of an unknown type.
            Exception: Whatever ``cloudpickle.loads`` raises for a payload
                that cannot be decoded.
        """
        # NOTE(security): unpickling can execute arbitrary code, so the RPC
        # endpoint must only be reachable by trusted peers.
        request = cloudpickle.loads(message)

        if isinstance(request, RPCGenerateRequest):
            return self.generate(identity, request)

        elif isinstance(request, RPCAbortRequest):
            return self.abort(identity, request)

        elif isinstance(request, RPCUtilityRequest):
            if request in [
                    RPCUtilityRequest.GET_MODEL_CONFIG,
                    RPCUtilityRequest.GET_PARALLEL_CONFIG,
                    RPCUtilityRequest.GET_DECODING_CONFIG,
                    RPCUtilityRequest.GET_SCHEDULER_CONFIG,
                    RPCUtilityRequest.GET_LORA_CONFIG
            ]:
                return self.get_config(identity, request)
            elif request == RPCUtilityRequest.DO_LOG_STATS:
                return self.do_log_stats(identity)
            elif request == RPCUtilityRequest.IS_SERVER_READY:
                return self.is_server_ready(identity)
            elif request == RPCUtilityRequest.IS_SERVER_HEALTHY:
                return self.check_health(identity)
            elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
                return self.is_tracing_enabled(identity)
            elif request == RPCUtilityRequest.START_PROFILE:
                return self.start_profile(identity)
            elif request == RPCUtilityRequest.STOP_PROFILE:
                return self.stop_profile(identity)
            else:
                raise ValueError(f"Unknown RPCUtilityRequest type: {request}")

        else:
            raise ValueError(f"Unknown RPCRequest type: {request}")

    async def run_server_loop(self):
        """Inner RPC Server Loop

        Receives requests forever and runs each handler as its own task so
        that a slow request does not block the ones behind it. The loop only
        ends when the surrounding task is cancelled or the socket fails.
        """

        running_tasks = set()
        while True:
            # Wait for a request.
            identity, message = await self.socket.recv_multipart()

            try:
                handler = self._make_handler_coro(identity, message)
            except Exception as e:
                # A single undecodable or unknown request must not take the
                # whole server down: log it and report the error back to the
                # requester instead of letting it escape the loop.
                logger.exception("Failed to route RPC request.")
                handler = self._send(identity, e)

            # Process the request async.
            task = asyncio.create_task(handler)

            # We need to keep around a strong reference to the task,
            # to avoid the task disappearing mid-execution as running tasks
            # can be GC'ed. Below is a common "fire-and-forget" tasks
            # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
            running_tasks.add(task)
            task.add_done_callback(running_tasks.discard)


async def run_server(server: AsyncEngineRPCServer):
    """Run ``server``'s request loop until it stops, then release resources.

    The loop is scheduled as a task on the running event loop. SIGINT and
    SIGTERM both cancel that task; the resulting cancellation is logged
    rather than propagated. ``server.cleanup()`` is always called on exit.
    """
    event_loop = asyncio.get_running_loop()

    # Schedule the RPC loop as a task so that a signal can cancel it.
    loop_task = event_loop.create_task(server.run_server_loop())

    # Interrupt / terminate both kill the server by cancelling its task.
    for signum in (signal.SIGINT, signal.SIGTERM):
        event_loop.add_signal_handler(signum, loop_task.cancel)

    try:
        await loop_task
    except asyncio.CancelledError:
        logger.info("vLLM ZMQ RPC Server was interrupted.")
    finally:
        # Release socket, context and engine regardless of how we got here.
        server.cleanup()


def run_rpc_server(async_engine_args: AsyncEngineArgs,
                   usage_context: UsageContext, rpc_path: str):
    """Build an ``AsyncEngineRPCServer`` and run it under uvloop.

    Args:
        async_engine_args: Engine arguments passed to the server.
        usage_context: Usage context passed to the server.
        rpc_path: ZMQ endpoint passed to the server.
    """
    rpc_server = AsyncEngineRPCServer(
        async_engine_args=async_engine_args,
        usage_context=usage_context,
        rpc_path=rpc_path,
    )
    # Blocks until the `run_server` coroutine has finished.
    uvloop.run(run_server(rpc_server))