"""SRT: SGLang Runtime"""

import asyncio
import dataclasses
import json
import multiprocessing as mp
import os
import sys
import threading
import time
from typing import List, Optional, Union

# Work around a Python shutdown issue: stub out threading._register_atexit so
# thread-pool atexit hooks cannot block interpreter exit.
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

import aiohttp
import psutil
import pydantic
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.constrained import disable_cache
from sglang.srt.conversation import (
    Conversation,
    SeparatorStyle,
    chat_template_exists,
    generate_chat_conv,
    register_conv_template,
)
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
from sglang.srt.managers.openai_protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    DeltaMessage,
    LogProbs,
    UsageInfo,
)
from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import handle_port_init
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

API_KEY_HEADER_NAME = "X-API-Key"

class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, api_key: str):
        super().__init__(app)
        self.api_key = api_key

    async def dispatch(self, request: Request, call_next):
        # extract API key from the request headers
        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
        if not api_key_header or api_key_header != self.api_key:
            return JSONResponse(
                status_code=403,
                content={"detail": "Invalid API Key"},
            )
        response = await call_next(request)
        return response
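
# Example (illustrative): when the server is launched with an API key
# (server_args.api_key), every request must carry it in the X-API-Key header.
# Assuming the server listens on localhost:30000:
#
#   import requests
#   requests.get(
#       "http://localhost:30000/get_model_info",
#       headers={"X-API-Key": "secret-key"},
#   )
#
# Requests with a missing or mismatched key get a 403 JSON response.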


app = FastAPI()
tokenizer_manager = None
chat_template_name = None


# FIXME: Remove this once we drop support for pydantic 1.x
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1


def jsonify_pydantic_model(obj: BaseModel):
    if IS_PYDANTIC_1:
        return obj.json(ensure_ascii=False)
    return obj.model_dump_json()


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.get("/get_model_info")
async def get_model_info():
    result = {
        "model_path": tokenizer_manager.model_path,
    }
    return result


@app.get("/get_server_args")
async def get_server_args():
    return dataclasses.asdict(tokenizer_manager.server_args)


@app.get("/flush_cache")
async def flush_cache():
    await tokenizer_manager.flush_cache()
    return Response(
        content="Cache flushed.\nPlease check backend logs for more details. "
        "(When there are running or waiting requests, the operation will not be performed.)\n",
        status_code=200,
    )


async def detokenize_logprob_tokens(token_logprobs):
    token_ids = [tid for tid, _ in token_logprobs]
    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
    return [(text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)]


async def stream_generator(obj: GenerateReqInput):
    async for out in tokenizer_manager.generate_request(obj):
        if obj.return_logprob and obj.return_text_in_logprobs:
            out["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
                out["meta_info"]["token_logprob"]
            )
        yield out


async def make_openai_style_logprobs(token_logprobs):
    ret_logprobs = LogProbs()

    for token_text, token_logprob in token_logprobs:
        ret_logprobs.tokens.append(token_text)
        ret_logprobs.token_logprobs.append(token_logprob)

        # Not supported yet.
        ret_logprobs.top_logprobs.append({})
        ret_logprobs.text_offset.append(-1)
    return ret_logprobs
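
# The object built above mirrors the OpenAI completions `logprobs` field, e.g.
# (illustrative values):
#   {"tokens": ["Hello", ","], "token_logprobs": [-0.1, -0.5],
#    "top_logprobs": [{}, {}], "text_offset": [-1, -1]}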


@app.post("/generate")
async def generate_request(obj: GenerateReqInput):
    obj.post_init()

    if obj.stream:

        async def stream_results():
            async for out in stream_generator(obj):
                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(stream_results(), media_type="text/event-stream")

    ret = await tokenizer_manager.generate_request(obj).__anext__()
    if obj.return_logprob and obj.return_text_in_logprobs:
        ret["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
            ret["meta_info"]["token_logprob"]
        )

    return ret
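
# Example (illustrative): a minimal /generate request body, matching the
# warmup request sent in launch_server below:
#   {"text": "Say this is a warmup request.",
#    "sampling_params": {"temperature": 0, "max_new_tokens": 16},
#    "stream": false}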


@app.post("/v1/completions")
async def v1_completions(raw_request: Request):
    request_json = await raw_request.json()
    request = CompletionRequest(**request_json)

    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
    assert request.n == 1

    adapted_request = GenerateReqInput(
        text=request.prompt,
        sampling_params={
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens,
            "stop": request.stop,
            "top_p": request.top_p,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
            "regex": request.regex,
        },
        return_logprob=request.logprobs is not None,
        return_text_in_logprobs=True,
        stream=request.stream,
    )
    adapted_request.post_init()

    if adapted_request.stream:

        async def generate_stream_resp():
            stream_buffer = ""
            n_prev_token = 0
            async for content in stream_generator(adapted_request):
                text = content["text"]
                prompt_tokens = content["meta_info"]["prompt_tokens"]
                completion_tokens = content["meta_info"]["completion_tokens"]

                if not stream_buffer:  # The first chunk
                    if request.echo:
                        # Prepend prompt in response text.
                        text = request.prompt + text
                    else:
                        # Skip prompt tokens if echo is disabled.
                        n_prev_token = prompt_tokens

                if request.logprobs is not None:
                    logprobs = await make_openai_style_logprobs(
                        content["meta_info"]["token_logprob"][n_prev_token:]
                    )
                    n_prev_token = len(content["meta_info"]["token_logprob"])
                else:
                    logprobs = None

                delta = text[len(stream_buffer) :]
                stream_buffer = content["text"]
                choice_data = CompletionResponseStreamChoice(
                    index=0,
                    text=delta,
                    logprobs=logprobs,
                    finish_reason=None,
                )
                chunk = CompletionStreamResponse(
                    id=content["meta_info"]["id"],
                    object="text_completion",
                    choices=[choice_data],
                    model=request.model,
                    usage=UsageInfo(
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                        total_tokens=prompt_tokens + completion_tokens,
                    ),
                )
                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")

    # Non-streaming response.
    ret = await generate_request(adapted_request)
    ret = ret[0] if isinstance(ret, list) else ret

    prompt_tokens = ret["meta_info"]["prompt_tokens"]
    completion_tokens = ret["meta_info"]["completion_tokens"]
    text = ret["text"]
    if request.echo:
        # Echo includes the prompt, so keep its logprobs as well.
        token_logprob_pos = 0
        text = request.prompt + text
    else:
        # Skip the prompt tokens in the returned logprobs.
        token_logprob_pos = prompt_tokens

    logprobs = (
        await make_openai_style_logprobs(
            ret["meta_info"]["token_logprob"][token_logprob_pos:]
        )
        if request.logprobs is not None
        else None
    )
    choice_data = CompletionResponseChoice(
        index=0,
        text=text,
        logprobs=logprobs,
        finish_reason=None,  # TODO(comaniac): Add finish reason.
    )

    response = CompletionResponse(
        id=ret["meta_info"]["id"],
        model=request.model,
        choices=[choice_data],
        usage=UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
    return response


@app.post("/v1/chat/completions")
async def v1_chat_completions(raw_request: Request):
    request_json = await raw_request.json()
    request = ChatCompletionRequest(**request_json)

    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
    assert request.n == 1

    # Prep the data needed for the underlying GenerateReqInput:
    #  - prompt: The full prompt string.
    #  - stop: Custom stop tokens.
    #  - image_data: None or a list of image strings (URLs or base64 strings).
    #    None skips any image processing in GenerateReqInput.
    if not isinstance(request.messages, str):
        # Apply chat template and its stop strings.
        if chat_template_name is None:
            # This flow doesn't support the full OpenAI spec. Verify that each
            # message's content is a plain string before proceeding:
            for m in request.messages:
                if not isinstance(m.content, str):
                    raise HTTPException(
                        status_code=503,
Cody Yu's avatar
Cody Yu committed
314
315
316
                        detail="Structured content requests not supported with "
                        "HuggingFace Chat Templates. "
                        "Make sure the server specifies a sglang chat template.",
                    )
            prompt = tokenizer_manager.tokenizer.apply_chat_template(
                request.messages, tokenize=False, add_generation_prompt=True
            )
            stop = request.stop
            image_data = None
        else:
            conv = generate_chat_conv(request, chat_template_name)
            prompt = conv.get_prompt()
            image_data = conv.image_data
            stop = conv.stop_str or []
            # stop_str may be a single string; normalize it to a list before
            # appending the request's own stop strings.
            if isinstance(stop, str):
                stop = [stop]
            if request.stop:
                if isinstance(request.stop, str):
                    stop.append(request.stop)
                else:
                    stop.extend(request.stop)
    else:
        # Use the raw prompt and stop strings if the messages is already a string.
        prompt = request.messages
        stop = request.stop
        image_data = None

    adapted_request = GenerateReqInput(
        text=prompt,
        image_data=image_data,
        sampling_params={
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens,
            "stop": stop,
            "top_p": request.top_p,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
            "regex": request.regex,
        },
        stream=request.stream,
    )
    adapted_request.post_init()

    if adapted_request.stream:

        async def generate_stream_resp():
            is_first = True

            stream_buffer = ""
            async for content in stream_generator(adapted_request):
                if is_first:
                    # First chunk with role
                    is_first = False
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=0,
                        delta=DeltaMessage(role="assistant"),
                        finish_reason=None,
                    )
                    chunk = ChatCompletionStreamResponse(
                        id=content["meta_info"]["id"],
                        choices=[choice_data],
                        model=request.model,
                    )
                    yield f"data: {jsonify_pydantic_model(chunk)}\n\n"

                text = content["text"]
                delta = text[len(stream_buffer) :]
                stream_buffer = text
                choice_data = ChatCompletionResponseStreamChoice(
                    index=0, delta=DeltaMessage(content=delta), finish_reason=None
                )
                chunk = ChatCompletionStreamResponse(
                    id=content["meta_info"]["id"],
                    choices=[choice_data],
                    model=request.model,
                )
                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")

    # Non-streaming response.
    ret = await generate_request(adapted_request)
    prompt_tokens = ret["meta_info"]["prompt_tokens"]
    completion_tokens = ret["meta_info"]["completion_tokens"]
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=ret["text"]),
        finish_reason=None,  # TODO(comaniac): Add finish reason.
    )
    response = ChatCompletionResponse(
        id=ret["meta_info"]["id"],
        model=request.model,
        choices=[choice_data],
        usage=UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
    return response
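
# Example (illustrative): a minimal OpenAI-style request body for the route
# above; the messages follow the standard role/content schema:
#   {"model": "default", "max_tokens": 16,
#    "messages": [{"role": "user", "content": "Say hello."}]}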


def launch_server(server_args, pipe_finish_writer):
    global tokenizer_manager
    global chat_template_name

    # Disable the disk cache if needed
    if server_args.disable_disk_cache:
        disable_cache()

    # Handle ports
    server_args.port, server_args.additional_ports = handle_port_init(
        server_args.port, server_args.additional_ports, server_args.tp_size
    )

    port_args = PortArgs(
        tokenizer_port=server_args.additional_ports[0],
        router_port=server_args.additional_ports[1],
        detokenizer_port=server_args.additional_ports[2],
        nccl_port=server_args.additional_ports[3],
        model_rpc_ports=server_args.additional_ports[4:],
    )

    # Load chat template if needed
    if server_args.chat_template is not None:
        print(f"Use chat template: {server_args.chat_template}")
        if not chat_template_exists(server_args.chat_template):
            if not os.path.exists(server_args.chat_template):
                raise RuntimeError(
                    f"Chat template {server_args.chat_template} is not a built-in template name "
                    "or a valid chat template file path."
                )
            with open(server_args.chat_template, "r") as filep:
                template = json.load(filep)
                try:
                    sep_style = SeparatorStyle[template["sep_style"]]
                except KeyError:
                    raise ValueError(
                        f"Unknown separator style: {template['sep_style']}"
                    ) from None
                register_conv_template(
                    Conversation(
                        name=template["name"],
                        system_template=template["system"] + "\n{system_message}",
                        system_message=template.get("system_message", ""),
                        roles=(template["user"], template["assistant"]),
                        sep_style=sep_style,
                        sep=template.get("sep", "\n"),
                        stop_str=template["stop_str"],
                    ),
                    override=True,
                )
            chat_template_name = template["name"]
        else:
            chat_template_name = server_args.chat_template
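
    # Example (illustrative): a template file for the file-based branch above
    # is a JSON object such as
    #   {"name": "my_template", "system": "<<SYS>>", "system_message": "",
    #    "user": "[INST]", "assistant": "[/INST]", "sep_style": "LLAMA2",
    #    "sep": " ", "stop_str": "</s>"}
    # where "sep_style" must name a member of sglang's SeparatorStyle enum
    # (all values here are placeholders).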

    # Launch processes
    tokenizer_manager = TokenizerManager(server_args, port_args)
    pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

    proc_router = mp.Process(
        target=start_router_process,
        args=(
            server_args,
            port_args,
            pipe_router_writer,
        ),
    )
    proc_router.start()
    proc_detoken = mp.Process(
        target=start_detokenizer_process,
        args=(
            server_args,
            port_args,
            pipe_detoken_writer,
        ),
    )
    proc_detoken.start()

    # Wait for the model to finish loading
    router_init_state = pipe_router_reader.recv()
    detoken_init_state = pipe_detoken_reader.recv()

    if router_init_state != "init ok" or detoken_init_state != "init ok":
        proc_router.kill()
        proc_detoken.kill()
        print("router init state:", router_init_state)
        print("detoken init state:", detoken_init_state)
        sys.exit(1)

    assert proc_router.is_alive() and proc_detoken.is_alive()

    if server_args.api_key:
        app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)

    def _launch_server():
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )

    def _wait_and_warmup():
        headers = {}
        url = server_args.url()
        if server_args.api_key:
            headers[API_KEY_HEADER_NAME] = server_args.api_key

        for _ in range(120):
            time.sleep(0.5)
            try:
                requests.get(url + "/get_model_info", timeout=5, headers=headers)
                break
            except requests.exceptions.RequestException as e:
                last_error = e
        else:
            # All attempts failed. The `except` target `e` is unbound once its
            # block ends, so report the saved exception instead.
            if pipe_finish_writer is not None:
                pipe_finish_writer.send(str(last_error))
            else:
                print(last_error, flush=True)
            return

        # Warmup
        try:
            # print("Warmup...", flush=True)
            res = requests.post(
                url + "/generate",
                json={
                    "text": "Say this is a warmup request.",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 16,
                    },
                },
                headers=headers,
                timeout=60,
            )
            # print(f"Warmup done. model response: {res.json()['text']}")
            # print("=" * 20, "Server is ready", "=" * 20, flush=True)
        except requests.exceptions.RequestException as e:
            if pipe_finish_writer is not None:
                pipe_finish_writer.send(str(e))
            else:
                print(e, flush=True)
            return

        if pipe_finish_writer is not None:
            pipe_finish_writer.send("init ok")

    t = threading.Thread(target=_wait_and_warmup)
    t.start()
    try:
        _launch_server()
    finally:
        t.join()
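
# Example (illustrative): launching the server in the current process. Only
# model_path is required by ServerArgs here; pipe_finish_writer may be None
# when no readiness notification is needed.
#
#   launch_server(ServerArgs(model_path="meta-llama/Llama-2-7b-chat-hf"), None)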


class Runtime:
    def __init__(
        self,
        model_path: str,
        tokenizer_path: Optional[str] = None,
        load_format: str = "auto",
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = True,
        mem_fraction_static: float = ServerArgs.mem_fraction_static,
        max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
        context_length: int = ServerArgs.context_length,
        tp_size: int = 1,
        schedule_heuristic: str = "lpm",
        attention_reduce_in_fp32: bool = False,
        random_seed: int = 42,
        log_level: str = "error",
        disable_radix_cache: bool = False,
        enable_flashinfer: bool = False,
        disable_regex_jump_forward: bool = False,
        disable_disk_cache: bool = False,
        api_key: str = "",
        port: Optional[int] = None,
        additional_ports: Optional[Union[List[int], int]] = None,
    ):
        host = "127.0.0.1"
        port, additional_ports = handle_port_init(port, additional_ports, tp_size)
        self.server_args = ServerArgs(
            model_path=model_path,
            tokenizer_path=tokenizer_path,
            host=host,
            port=port,
            additional_ports=additional_ports,
            load_format=load_format,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            mem_fraction_static=mem_fraction_static,
            max_prefill_num_token=max_prefill_num_token,
            context_length=context_length,
            tp_size=tp_size,
            schedule_heuristic=schedule_heuristic,
            attention_reduce_in_fp32=attention_reduce_in_fp32,
            random_seed=random_seed,
            log_level=log_level,
            disable_radix_cache=disable_radix_cache,
            enable_flashinfer=enable_flashinfer,
            disable_regex_jump_forward=disable_regex_jump_forward,
            disable_disk_cache=disable_disk_cache,
            api_key=api_key,
        )

        self.url = self.server_args.url()
        self.generate_url = (
            f"http://{self.server_args.host}:{self.server_args.port}/generate"
        )

        self.pid = None
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)
        proc = mp.Process(target=launch_server, args=(self.server_args, pipe_writer))
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "init ok":
            self.shutdown()
            raise RuntimeError("Launch failed. Please see the error messages above.")

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        if self.pid is not None:
            try:
                parent = psutil.Process(self.pid)
            except psutil.NoSuchProcess:
                return
            children = parent.children(recursive=True)
            for child in children:
                child.kill()
            psutil.wait_procs(children, timeout=5)
            parent.kill()
            parent.wait(timeout=5)
            self.pid = None

    def get_tokenizer(self):
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
        )

    async def add_request(
        self,
        prompt: str,
        sampling_params,
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "stream": True,
        }

        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        cur = data["text"][pos:]
                        if cur:
                            yield cur
                        pos += len(cur)
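
    # Example (illustrative) of consuming add_request from an event loop:
    #
    #   runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
    #   async for piece in runtime.add_request("Hello", {"max_new_tokens": 16}):
    #       print(piece, end="", flush=True)
    #   runtime.shutdown()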

    def __del__(self):
        self.shutdown()