"""SRT: SGLang Runtime"""
import asyncio
import json
import multiprocessing as mp
import os
import sys
import threading
import time
from typing import List, Optional, Union

# Work around a Python bug: make threading._register_atexit a no-op so that
# worker threads can still be spawned while the interpreter is shutting down.
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

import aiohttp
import psutil
import pydantic
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.conversation import (
    Conversation,
    SeparatorStyle,
    chat_template_exists,
    generate_chat_conv,
    register_conv_template,
)
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.openai_protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    DeltaMessage,
    UsageInfo,
)
from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import alloc_usable_network_port, handle_port_init

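# Use uvloop as the asyncio event loop implementation.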
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


app = FastAPI()
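# Set by launch_server().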
tokenizer_manager = None
chat_template_name = None


# FIXME: Remove this once we drop support for pydantic 1.x
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1


def jsonify_pydantic_model(obj: BaseModel):
    """Serialize a pydantic model to JSON under both pydantic 1.x and 2.x."""
    if IS_PYDANTIC_1:
        return obj.json(ensure_ascii=False)
    return obj.model_dump_json()


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.get("/get_model_info")
async def get_model_info():
    result = {
        "model_path": tokenizer_manager.model_path,
    }
    return result


@app.get("/flush_cache")
async def flush_cache():
    await tokenizer_manager.flush_cache()
    return Response(
        content="Cache flushed.\nPlease check backend logs for more details. "
        "(When there are running or waiting requests, the operation will not be performed.)\n",
        status_code=200,
    )


async def stream_generator(obj):
    """Relay incremental outputs from the TokenizerManager for one request."""
    async for out in tokenizer_manager.generate_request(obj):
        yield out


@app.post("/generate")
async def generate_request(obj: GenerateReqInput):
    obj.post_init()

    if obj.stream:

        async def stream_results():
            async for out in stream_generator(obj):
                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(stream_results(), media_type="text/event-stream")

    ret = await tokenizer_manager.generate_request(obj).__anext__()
    return ret


@app.post("/v1/completions")
async def v1_completions(raw_request: Request):
    request_json = await raw_request.json()
    request = CompletionRequest(**request_json)

    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
    assert request.n == 1

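    # Map the OpenAI completion request onto the internal GenerateReqInput.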
    adapted_request = GenerateReqInput(
        text=request.prompt,
        sampling_params={
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens,
            "stop": request.stop,
            "top_p": request.top_p,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
        },
        stream=request.stream,
    )
    adapted_request.post_init()

    if adapted_request.stream:

        async def generate_stream_resp():
            stream_buffer = ""
            async for content in stream_generator(adapted_request):
                text = content["text"]
                prompt_tokens = content["meta_info"]["prompt_tokens"]
                completion_tokens = content["meta_info"]["completion_tokens"]

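                # The backend streams cumulative text, so emit only the suffix
                # that is new since the previous chunk.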
                delta = text[len(stream_buffer) :]
                stream_buffer = text
                choice_data = CompletionResponseStreamChoice(
                    index=0,
                    text=delta,
                    logprobs=None,
                    finish_reason=None,
                )
                chunk = CompletionStreamResponse(
                    id=content["meta_info"]["id"],
                    object="text_completion",
                    choices=[choice_data],
                    model=request.model,
                    usage=UsageInfo(
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                        total_tokens=prompt_tokens + completion_tokens,
                    ),
                )
                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")

    # Non-streaming response.
    ret = await generate_request(adapted_request)

    choice_data = CompletionResponseChoice(
        index=0,
        text=ret["text"],
        logprobs=None,
        finish_reason=None,  # TODO(comaniac): Add finish reason.
    )

    prompt_tokens = ret["meta_info"]["prompt_tokens"]
    completion_tokens = ret["meta_info"]["completion_tokens"]
    response = CompletionResponse(
        id=ret["meta_info"]["id"],
        model=request.model,
        choices=[choice_data],
        usage=UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
    return response


@app.post("/v1/chat/completions")
async def v1_chat_completions(raw_request: Request):
    request_json = await raw_request.json()
    request = ChatCompletionRequest(**request_json)

    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
    assert request.n == 1

    # Prep the data needed for the underlying GenerateReqInput:
    #  - prompt: The full prompt string.
    #  - stop: Custom stop tokens.
    #  - image_data: None or a list of image strings (URLs or base64 strings).
    #    None skips any image processing in GenerateReqInput.
    if not isinstance(request.messages, str):
        # Apply chat template and its stop strings.
        if chat_template_name is None:
            # This flow doesn't support the full OpenAI spec. Verify that each
            # message's content is a plain string before proceeding:
            for m in request.messages:
                if not isinstance(m.content, str):
                    raise HTTPException(
                        status_code=503,
                        detail="Structured content requests are not supported with HuggingFace chat templates. Make sure the server specifies an sglang chat template.",
                    )
            prompt = tokenizer_manager.tokenizer.apply_chat_template(
                request.messages, tokenize=False, add_generation_prompt=True
            )
            stop = request.stop
            image_data = None
        else:
            conv = generate_chat_conv(request, chat_template_name)
            prompt = conv.get_prompt()
            image_data = conv.image_data
            stop = conv.stop_str or []
            # conv.stop_str may be a single string; normalize to a list so the
            # request's own stop strings can be appended below.
            if isinstance(stop, str):
                stop = [stop]
            if request.stop:
                if isinstance(request.stop, str):
                    stop.append(request.stop)
                else:
                    stop.extend(request.stop)
    else:
        # Use the raw prompt and stop strings if the messages field is already a string.
        prompt = request.messages
        stop = request.stop
        image_data = None

    adapted_request = GenerateReqInput(
        text=prompt,
        image_data=image_data,
        sampling_params={
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens,
            "stop": stop,
            "top_p": request.top_p,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
        },
        stream=request.stream,
    )
    adapted_request.post_init()

    if adapted_request.stream:

        async def generate_stream_resp():
            is_first = True

            stream_buffer = ""
            async for content in stream_generator(adapted_request):
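                # Following the OpenAI streaming format, the first chunk
                # carries only the assistant role; later chunks carry deltas.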
                if is_first:
                    # First chunk with role
                    is_first = False
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=0,
                        delta=DeltaMessage(role="assistant"),
                        finish_reason=None,
                    )
                    chunk = ChatCompletionStreamResponse(
                        id=content["meta_info"]["id"],
                        choices=[choice_data],
                        model=request.model,
                    )
                    yield f"data: {jsonify_pydantic_model(chunk)}\n\n"

                text = content["text"]
                delta = text[len(stream_buffer) :]
                stream_buffer = text
                choice_data = ChatCompletionResponseStreamChoice(
                    index=0, delta=DeltaMessage(content=delta), finish_reason=None
                )
                chunk = ChatCompletionStreamResponse(
                    id=content["meta_info"]["id"],
                    choices=[choice_data],
                    model=request.model,
                )
                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")

    # Non-streaming response.
    ret = await generate_request(adapted_request)
    prompt_tokens = ret["meta_info"]["prompt_tokens"]
    completion_tokens = ret["meta_info"]["completion_tokens"]
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=ret["text"]),
        finish_reason=None,  # TODO(comaniac): Add finish reason.
    )
    response = ChatCompletionResponse(
        id=ret["meta_info"]["id"],
        model=request.model,
        choices=[choice_data],
        usage=UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
    return response


def launch_server(server_args, pipe_finish_writer):
    global tokenizer_manager
    global chat_template_name

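    # Architecture: this process hosts the FastAPI app and the TokenizerManager;
    # a router subprocess schedules batches across the model workers, and a
    # detokenizer subprocess converts output token ids back into text.
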
    # Handle ports
    server_args.port, server_args.additional_ports = handle_port_init(
        server_args.port, server_args.additional_ports, server_args.tp_size
    )

    port_args = PortArgs(
        tokenizer_port=server_args.additional_ports[0],
        router_port=server_args.additional_ports[1],
        detokenizer_port=server_args.additional_ports[2],
        nccl_port=server_args.additional_ports[3],
        model_rpc_ports=server_args.additional_ports[4:],
    )

    # Load chat template if needed
    if server_args.chat_template is not None:
        print(f"Use chat template: {server_args.chat_template}")
        if not chat_template_exists(server_args.chat_template):
            if not os.path.exists(server_args.chat_template):
                raise RuntimeError(
                    f"Chat template {server_args.chat_template} is not a built-in template name "
                    "or a valid chat template file path."
                )
            with open(server_args.chat_template, "r") as filep:
                template = json.load(filep)
                try:
                    sep_style = SeparatorStyle[template["sep_style"]]
                except KeyError:
                    raise ValueError(
                        f"Unknown separator style: {template['sep_style']}"
                    ) from None
                register_conv_template(
                    Conversation(
                        name=template["name"],
                        system_template=template["system"] + "\n{system_message}",
                        system_message=template.get("system_message", ""),
                        roles=(template["user"], template["assistant"]),
                        sep_style=sep_style,
                        sep=template.get("sep", "\n"),
                        stop_str=template["stop_str"],
                    ),
                    override=True,
                )
            chat_template_name = template["name"]
        else:
            chat_template_name = server_args.chat_template

    # Launch processes
    tokenizer_manager = TokenizerManager(server_args, port_args)
    pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

    proc_router = mp.Process(
        target=start_router_process,
        args=(
            server_args,
            port_args,
            pipe_router_writer,
        ),
    )
    proc_router.start()
    proc_detoken = mp.Process(
        target=start_detokenizer_process,
        args=(
            server_args,
            port_args,
            pipe_detoken_writer,
        ),
    )
    proc_detoken.start()

    # Wait for the model to finish loading
    router_init_state = pipe_router_reader.recv()
    detoken_init_state = pipe_detoken_reader.recv()

    if router_init_state != "init ok" or detoken_init_state != "init ok":
        proc_router.kill()
        proc_detoken.kill()
        print("router init state:", router_init_state)
        print("detoken init state:", detoken_init_state)
        sys.exit(1)

    assert proc_router.is_alive() and proc_detoken.is_alive()

    def _launch_server():
        # Launch api server
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )

    t = threading.Thread(target=_launch_server)
    t.start()

    # Poll the server until it is ready to serve requests (up to 60 seconds).
    url = server_args.url()
    last_exception = None
    for _ in range(60):
        time.sleep(1)
        try:
            requests.get(url + "/get_model_info", timeout=5)
            break
        except requests.exceptions.RequestException as e:
            last_exception = e
    else:
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(str(last_exception))
        else:
            print(last_exception, flush=True)
        return

    # Warmup: send one short generation request before reporting readiness.
    try:
        print("Warmup...", flush=True)
        res = requests.post(
            url + "/generate",
            json={
                "text": "Say this is a warmup request.",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 16,
                },
            },
            timeout=60,
        )
        print(f"Warmup done. model response: {res.json()['text']}")
    except requests.exceptions.RequestException as e:
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(str(e))
        else:
            print(e, flush=True)
        return

    if pipe_finish_writer is not None:
        pipe_finish_writer.send("init ok")


class Runtime:
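    """Launch the SRT server in a background process and wrap its HTTP API.

    A minimal usage sketch (the model path below is illustrative):

        runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
        print(runtime.url)
        runtime.shutdown()
    """
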
    def __init__(
        self,
        model_path: str,
        tokenizer_path: Optional[str] = None,
        load_format: str = "auto",
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = True,
        mem_fraction_static: float = ServerArgs.mem_fraction_static,
        max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
        tp_size: int = 1,
        model_mode: List[str] = (),
        schedule_heuristic: str = "lpm",
        random_seed: int = 42,
        log_level: str = "error",
        port: Optional[int] = None,
        additional_ports: Optional[Union[List[int], int]] = None,
    ):
        host = "127.0.0.1"
        port, additional_ports = handle_port_init(port, additional_ports, tp_size)
        self.server_args = ServerArgs(
            model_path=model_path,
            tokenizer_path=tokenizer_path,
            host=host,
            port=port,
            additional_ports=additional_ports,
            load_format=load_format,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            mem_fraction_static=mem_fraction_static,
            max_prefill_num_token=max_prefill_num_token,
            tp_size=tp_size,
            model_mode=model_mode,
            schedule_heuristic=schedule_heuristic,
            random_seed=random_seed,
            log_level=log_level,
        )

        self.url = self.server_args.url()
        self.generate_url = (
            f"http://{self.server_args.host}:{self.server_args.port}/generate"
        )

        self.pid = None
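        # Launch the server in a subprocess and wait for its "init ok" message.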
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)
        proc = mp.Process(target=launch_server, args=(self.server_args, pipe_writer))
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "init ok":
            self.shutdown()
            raise RuntimeError("Launch failed. Please see the error messages above.")

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        if self.pid is not None:
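            # Kill the server's children (router, detokenizer, model workers)
            # before the server process itself.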
            try:
                parent = psutil.Process(self.pid)
            except psutil.NoSuchProcess:
                return
            children = parent.children(recursive=True)
            for child in children:
                child.kill()
            psutil.wait_procs(children, timeout=5)
            parent.kill()
            parent.wait(timeout=5)
            self.pid = None

    def get_tokenizer(self):
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
        )

    async def add_request(
        self,
        prompt: str,
        sampling_params,
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "stream": True,
        }

        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
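        # Parse the response as server-sent events, yielding only the newly
        # generated suffix of each cumulative "data:" payload.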
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        cur = data["text"][pos:]
                        if cur:
                            yield cur
                        pos += len(cur)

    def __del__(self):
        self.shutdown()