server.py 11.2 KB
Newer Older
Lianmin Zheng's avatar
Lianmin Zheng committed
1
"""SRT: SGLang Runtime"""
2

Lianmin Zheng's avatar
Lianmin Zheng committed
3
import asyncio
Liangsheng Yin's avatar
Liangsheng Yin committed
4
import dataclasses
Lianmin Zheng's avatar
Lianmin Zheng committed
5
import json
6
import logging
Lianmin Zheng's avatar
Lianmin Zheng committed
7
import multiprocessing as mp
Cody Yu's avatar
Cody Yu committed
8
import os
Lianmin Zheng's avatar
Lianmin Zheng committed
9
10
11
import sys
import threading
import time
12
from http import HTTPStatus
13
from typing import Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
14

15
# Fix a bug of Python threading
16
17
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

Ying Sheng's avatar
Ying Sheng committed
18
import aiohttp
Lianmin Zheng's avatar
Lianmin Zheng committed
19
20
21
22
import psutil
import requests
import uvicorn
import uvloop
23
from fastapi import FastAPI, Request
24
from fastapi.responses import JSONResponse, Response, StreamingResponse
Liangsheng Yin's avatar
Liangsheng Yin committed
25

Lianmin Zheng's avatar
Lianmin Zheng committed
26
from sglang.backend.runtime_endpoint import RuntimeEndpoint
Liangsheng Yin's avatar
Liangsheng Yin committed
27
from sglang.srt.constrained import disable_cache
Ying Sheng's avatar
Ying Sheng committed
28
from sglang.srt.hf_transformers_utils import get_tokenizer
Lianmin Zheng's avatar
Lianmin Zheng committed
29
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
30
from sglang.srt.managers.io_struct import GenerateReqInput
31
32
from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single
from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi
Lianmin Zheng's avatar
Lianmin Zheng committed
33
from sglang.srt.managers.tokenizer_manager import TokenizerManager
34
from sglang.srt.openai_api_adapter import (
Liangsheng Yin's avatar
Liangsheng Yin committed
35
36
37
38
    load_chat_template_for_openai_api,
    v1_chat_completions,
    v1_completions,
)
39
from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs
Lianmin Zheng's avatar
Lianmin Zheng committed
40
from sglang.srt.utils import (
Liangsheng Yin's avatar
Liangsheng Yin committed
41
42
    API_KEY_HEADER_NAME,
    APIKeyValidatorMiddleware,
Lianmin Zheng's avatar
Lianmin Zheng committed
43
44
    allocate_init_ports,
    assert_pkg_version,
45
    enable_show_time_cost,
Lianmin Zheng's avatar
Lianmin Zheng committed
46
)
47
48
from sglang.utils import get_exception_traceback

Lianmin Zheng's avatar
Lianmin Zheng committed
49
50
# Use uvloop as the asyncio event loop implementation for the whole process.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

Lianmin Zheng's avatar
Lianmin Zheng committed
51

Lianmin Zheng's avatar
Lianmin Zheng committed
52
53
54
55
app = FastAPI()
# Global handle to the in-process TokenizerManager; assigned by launch_server()
# and read by every request handler below.
tokenizer_manager = None


56
57
58
59
60
61
@app.get("/health")
async def health() -> Response:
    """Health check."""
    # An empty 200 response signals the HTTP loop is alive.
    return Response(status_code=200)


Lianmin Zheng's avatar
Lianmin Zheng committed
62
63
64
65
66
67
68
@app.get("/get_model_info")
async def get_model_info():
    """Expose basic model information (currently only the model path)."""
    return {"model_path": tokenizer_manager.model_path}

Cody Yu's avatar
Cody Yu committed
69

Liangsheng Yin's avatar
Liangsheng Yin committed
70
71
72
73
74
@app.get("/get_server_args")
async def get_server_args():
    """Return the launch-time server arguments as a plain dict."""
    args = tokenizer_manager.server_args
    return dataclasses.asdict(args)


Liangsheng Yin's avatar
Liangsheng Yin committed
75
76
@app.get("/flush_cache")
async def flush_cache():
    """Ask the backend to flush its cache and report the outcome message."""
    tokenizer_manager.flush_cache()
    message = (
        "Cache flushed.\nPlease check backend logs for more details. "
        "(When there are running or waiting requests, the operation will not be performed.)\n"
    )
    return Response(content=message, status_code=200)


85
async def generate_request(obj: GenerateReqInput, request: Request):
    """Handle a generation request, either as a single JSON reply or as an
    SSE stream, depending on ``obj.stream``."""
    # Non-streaming: await the single result and return it directly.
    if not obj.stream:
        try:
            return await tokenizer_manager.generate_request(obj, request).__anext__()
        except ValueError as e:
            return JSONResponse(
                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
            )

    # Streaming: forward each partial output as a server-sent event.
    async def stream_results():
        try:
            async for out in tokenizer_manager.generate_request(obj, request):
                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
        except ValueError as e:
            out = {"error": {"message": str(e)}}
            yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        stream_results(),
        media_type="text/event-stream",
        background=tokenizer_manager.create_abort_task(obj),
    )

Lianmin Zheng's avatar
Lianmin Zheng committed
108

Ying Sheng's avatar
Ying Sheng committed
109
110
111
# Register the generate handler under both POST and PUT /generate.
app.post("/generate")(generate_request)
app.put("/generate")(generate_request)

Lianmin Zheng's avatar
Lianmin Zheng committed
112

Lianmin Zheng's avatar
Lianmin Zheng committed
113
# OpenAI-compatible /v1/completions endpoint; delegates to the adapter module.
@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
    return await v1_completions(tokenizer_manager, raw_request)
Lianmin Zheng's avatar
Lianmin Zheng committed
116
117


Cody Yu's avatar
Cody Yu committed
118
# OpenAI-compatible /v1/chat/completions endpoint; delegates to the adapter module.
@app.post("/v1/chat/completions")
async def openai_v1_chat_completions(raw_request: Request):
    return await v1_chat_completions(tokenizer_manager, raw_request)
121

Lianmin Zheng's avatar
Lianmin Zheng committed
122

Yuanhan Zhang's avatar
Yuanhan Zhang committed
123
def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
    """Launch the SRT server.

    Starts the in-process TokenizerManager, the controller ("router") and
    detokenizer subprocesses, fires a background warmup thread, then blocks in
    uvicorn serving the FastAPI app.

    Args:
        server_args: full server configuration (see server_args.py::ServerArgs).
        pipe_finish_writer: optional multiprocessing pipe end; receives
            "init ok" after a successful warmup request, or an exception
            traceback string if warmup fails. May be None.
        model_overide_args: optional dict of model config overrides, forwarded
            to the TokenizerManager and the controller process.
    """
    global tokenizer_manager

    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    # Set global environments
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    if server_args.show_time_cost:
        enable_show_time_cost()
    if server_args.disable_disk_cache:
        disable_cache()
    if server_args.enable_flashinfer:
        assert_pkg_version("flashinfer", "0.0.4")
    if server_args.chat_template:
        # TODO: replace this with huggingface transformers template
        load_chat_template_for_openai_api(server_args.chat_template)

    # Allocate ports
    server_args.port, server_args.additional_ports = allocate_init_ports(
        server_args.port,
        server_args.additional_ports,
        server_args.tp_size,
        server_args.dp_size,
    )

    # Init local models port args.
    # Layout of additional_ports: [tokenizer, router, detokenizer] followed by
    # (tp + 1) ports per data-parallel replica: one nccl port then tp model ports.
    ports = server_args.additional_ports
    tp = server_args.tp_size
    model_port_args = []
    for i in range(server_args.dp_size):
        model_port_args.append(
            ModelPortArgs(
                nccl_port=ports[3 + i * (tp + 1)],
                model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)],
            )
        )
    port_args = PortArgs(
        tokenizer_port=ports[0],
        router_port=ports[1],
        detokenizer_port=ports[2],
        model_port_args=model_port_args,
    )

    # Launch processes
    tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
    pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

    # Single-replica deployments use the single controller; dp_size > 1 uses
    # the multi-controller.
    if server_args.dp_size == 1:
        start_process = start_controller_process_single
    else:
        start_process = start_controller_process_multi
    proc_router = mp.Process(
        target=start_process,
        args=(server_args, port_args, pipe_router_writer, model_overide_args),
    )
    proc_router.start()
    proc_detoken = mp.Process(
        target=start_detokenizer_process,
        args=(
            server_args,
            port_args,
            pipe_detoken_writer,
        ),
    )
    proc_detoken.start()

    # Wait for the model to finish loading
    # (each child sends "init ok" through its pipe once ready).
    router_init_state = pipe_router_reader.recv()
    detoken_init_state = pipe_detoken_reader.recv()

    if router_init_state != "init ok" or detoken_init_state != "init ok":
        # Tear down both children before exiting so no orphan survives.
        proc_router.kill()
        proc_detoken.kill()
        print(
            f"Initialization failed. router_init_state: {router_init_state}", flush=True
        )
        print(
            f"Initialization failed. detoken_init_state: {detoken_init_state}",
            flush=True,
        )
        sys.exit(1)
    assert proc_router.is_alive() and proc_detoken.is_alive()

    if server_args.api_key and server_args.api_key != "":
        app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)

    # Send a warmup request
    def _wait_and_warmup():
        # Runs in a background thread: poll until the HTTP server answers,
        # then issue one short generation to warm the model up.
        headers = {}
        url = server_args.url()
        if server_args.api_key:
            headers[API_KEY_HEADER_NAME] = server_args.api_key

        # Wait until the server is launched (up to 120 * 0.5s = 60 seconds).
        for _ in range(120):
            time.sleep(0.5)
            try:
                requests.get(url + "/get_model_info", timeout=5, headers=headers)
                break
            except requests.exceptions.RequestException as e:
                pass

        # Send a warmup request
        try:
            res = requests.post(
                url + "/generate",
                json={
                    "text": "The capital city of France is",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 16,
                    },
                },
                headers=headers,
                timeout=600,
            )
            assert res.status_code == 200
        except Exception as e:
            # Report the traceback to the parent (if any) before re-raising.
            if pipe_finish_writer is not None:
                pipe_finish_writer.send(get_exception_traceback())
            print(f"Initialization failed. warmup error: {e}")
            raise e

        if pipe_finish_writer is not None:
            pipe_finish_writer.send("init ok")

    t = threading.Thread(target=_wait_and_warmup)
    t.start()

    # Listen for requests (blocks until the server is shut down).
    try:
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )
    finally:
        t.join()
Lianmin Zheng's avatar
Lianmin Zheng committed
268
269
270
271
272


class Runtime:
    """Client-side handle that launches an SRT server in a subprocess and
    talks to it over HTTP.

    Construction blocks until the child reports "init ok" through a pipe;
    on failure the child is killed and a RuntimeError is raised.
    """

    def __init__(
        self,
        log_level: str = "error",
        model_overide_args: Optional[dict] = None,
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Pre-allocate ports so the URL is known before the child starts.
        self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
            self.server_args.port,
            self.server_args.additional_ports,
            self.server_args.tp_size,
            self.server_args.dp_size,
        )

        self.url = self.server_args.url()
        self.generate_url = (
            f"http://{self.server_args.host}:{self.server_args.port}/generate"
        )

        self.pid = None
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)
        proc = mp.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer, model_overide_args),
        )
        proc.start()
        # Close our copy of the write end so recv() below gets EOF if the
        # child dies without sending anything.
        pipe_writer.close()
        self.pid = proc.pid

        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "init ok":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        """Kill the server process tree (children first) and clear self.pid.

        Idempotent: does nothing if the server was never started or is gone.
        """
        if self.pid is not None:
            try:
                parent = psutil.Process(self.pid)
            except psutil.NoSuchProcess:
                return
            children = parent.children(recursive=True)
            for child in children:
                child.kill()
            psutil.wait_procs(children, timeout=5)
            parent.kill()
            parent.wait(timeout=5)
            self.pid = None

    def get_tokenizer(self):
        """Load the tokenizer configured for this runtime (in this process)."""
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
        )

    async def add_request(
        self,
        prompt: str,
        sampling_params,
    ):
        """Stream a generation: POST to /generate with stream=True and yield
        each new text fragment as it arrives over SSE.

        Args:
            prompt: input text.
            sampling_params: dict of sampling parameters, sent as-is.

        Yields:
            str: the incremental text produced since the previous yield.
        """
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "stream": True,
        }
        # Number of characters already yielded; the server resends the full
        # text each event, so we slice off the prefix we have seen.
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        cur = data["text"][pos:]
                        if cur:
                            yield cur
                        pos += len(cur)

    def __del__(self):
        # Best-effort cleanup of the server process on garbage collection.
        self.shutdown()