# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The entry point of inference server.
SRT = SGLang Runtime.
"""

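# A typical way to launch this server (a sketch: the flags follow SGLang's documented
# CLI, and the model path below is only illustrative):
#
#     python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --port 30000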
import asyncio
import atexit
import dataclasses
import json
import logging
import multiprocessing as mp
import os
import signal
import sys
import threading
import time
from http import HTTPStatus
from typing import AsyncIterator, Dict, List, Optional, Union

# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

import aiohttp
import orjson
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from uvicorn.config import LOGGING_CONFIG

from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.data_parallel_controller import (
    run_data_parallel_controller_process,
)
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
from sglang.srt.managers.io_struct import (
    CloseSessionReqInput,
    EmbeddingReqInput,
    GenerateReqInput,
    OpenSessionReqInput,
    UpdateWeightReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer, time_func_latency
from sglang.srt.openai_api.adapter import (
    load_chat_template_for_openai_api,
    v1_batches,
    v1_cancel_batch,
    v1_chat_completions,
    v1_completions,
    v1_delete_file,
    v1_embeddings,
    v1_files_create,
    v1_retrieve_batch,
    v1_retrieve_file,
    v1_retrieve_file_content,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    add_api_key_middleware,
    add_prometheus_middleware,
    assert_pkg_version,
    configure_logger,
    delete_directory,
    is_port_available,
    kill_process_tree,
    maybe_set_triton_cache_manager,
    prepare_model_and_tokenizer,
    set_prometheus_multiproc_dir,
    set_ulimit,
)
from sglang.utils import get_exception_traceback
from sglang.version import __version__

logger = logging.getLogger(__name__)

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

tokenizer_manager: TokenizerManager = None
_max_total_num_tokens = None

##### Native API endpoints #####

@app.get("/health")
async def health() -> Response:
    """Check the health of the http server."""
    return Response(status_code=200)


@app.get("/health_generate")
async def health_generate(request: Request) -> Response:
    """Check the health of the inference server by generating one token."""

    if tokenizer_manager.is_generation:
        gri = GenerateReqInput(
            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
        )
    else:
        gri = EmbeddingReqInput(
            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
        )

    try:
        async for _ in tokenizer_manager.generate_request(gri, request):
            break
        return Response(status_code=200)
    except Exception as e:
        logger.exception(e)
        return Response(status_code=503)
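# Quick liveness-probe sketch (assumes a local server; port 30000 is an assumption,
# not something this file sets):
#
#     import requests
#     assert requests.get("http://localhost:30000/health").status_code == 200
#     assert requests.get("http://localhost:30000/health_generate").status_code == 200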


@app.get("/get_model_info")
async def get_model_info():
    """Get the model information."""
    result = {
        "model_path": tokenizer_manager.model_path,
        "tokenizer_path": tokenizer_manager.server_args.tokenizer_path,
        "is_generation": tokenizer_manager.is_generation,
    }
    return result


@app.get("/get_server_info")
async def get_server_info():
    try:
        return await _get_server_info()

    except Exception as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


@app.post("/flush_cache")
async def flush_cache():
    """Flush the radix cache."""
    tokenizer_manager.flush_cache()
    return Response(
        content="Cache flushed.\nPlease check backend logs for more details. "
        "(When there are running or waiting requests, the operation will not be performed.)\n",
        status_code=200,
    )


def start_profile():
    """Start profiling."""
    tokenizer_manager.start_profile()


def stop_profile():
    """Stop profiling."""
    tokenizer_manager.stop_profile()


@app.get("/start_profile")
@app.post("/start_profile")
async def start_profile_async():
    """Start profiling."""
    tokenizer_manager.start_profile()
    return Response(
        content="Start profiling.\n",
        status_code=200,
    )


@app.get("/stop_profile")
@app.post("/stop_profile")
async def stop_profile_async():
    """Stop profiling."""
    tokenizer_manager.stop_profile()
    return Response(
        content="Stop profiling. This will take some time.\n",
        status_code=200,
    )


@app.post("/update_weights")
@time_func_latency
async def update_weights(obj: UpdateWeightReqInput, request: Request):
    """Update the weights in place without re-launching the server."""
    success, message = await tokenizer_manager.update_weights(obj, request)
    content = {"success": success, "message": message}
    if success:
        return ORJSONResponse(
            content,
            status_code=HTTPStatus.OK,
        )
    else:
        return ORJSONResponse(
            content,
            status_code=HTTPStatus.BAD_REQUEST,
        )


@app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request):
    """Open a session, and return its unique session id."""
    try:
        session_id = await tokenizer_manager.open_session(obj, request)
        return session_id
    except Exception as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


@app.api_route("/close_session", methods=["GET", "POST"])
async def close_session(obj: CloseSessionReqInput, request: Request):
    """Close the session."""
    try:
        await tokenizer_manager.close_session(obj, request)
        return Response(status_code=200)
    except Exception as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


@time_func_latency
async def generate_request(obj: GenerateReqInput, request: Request):
    """Handle a generate request."""
    if obj.stream:

        async def stream_results() -> AsyncIterator[bytes]:
            try:
                async for out in tokenizer_manager.generate_request(obj, request):
                    yield b"data: " + orjson.dumps(
                        out, option=orjson.OPT_NON_STR_KEYS
                    ) + b"\n\n"
            except ValueError as e:
                out = {"error": {"message": str(e)}}
                yield b"data: " + orjson.dumps(
                    out, option=orjson.OPT_NON_STR_KEYS
                ) + b"\n\n"
            yield b"data: [DONE]\n\n"

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
            background=tokenizer_manager.create_abort_task(obj),
        )
    else:
        try:
            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
            return ret
        except ValueError as e:
            return ORJSONResponse(
                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
            )


# FastAPI implicitly converts the JSON in the request body to obj (a dataclass).
app.post("/generate")(generate_request)
app.put("/generate")(generate_request)
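# Example /generate request (a sketch: the host/port assume a default local launch,
# and the payload shape mirrors the warmup request sent in _wait_and_warmup below):
#
#     import requests
#     res = requests.post(
#         "http://localhost:30000/generate",
#         json={
#             "text": "The capital city of France is",
#             "sampling_params": {"temperature": 0, "max_new_tokens": 8},
#         },
#     )
#     print(res.json())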


@time_func_latency
async def encode_request(obj: EmbeddingReqInput, request: Request):
    """Handle an embedding request."""
    try:
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


app.post("/encode")(encode_request)
app.put("/encode")(encode_request)


@time_func_latency
async def classify_request(obj: EmbeddingReqInput, request: Request):
    """Handle a reward model request. For now, the arguments and return values are the same as for embedding models."""
    try:
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


app.post("/classify")(classify_request)
app.put("/classify")(classify_request)


##### OpenAI-compatible API endpoints #####


@app.post("/v1/completions")
@time_func_latency
async def openai_v1_completions(raw_request: Request):
    return await v1_completions(tokenizer_manager, raw_request)


@app.post("/v1/chat/completions")
@time_func_latency
async def openai_v1_chat_completions(raw_request: Request):
    return await v1_chat_completions(tokenizer_manager, raw_request)
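# Example OpenAI-compatible request (a sketch: host/port assume a default local launch,
# and the model name is whatever /v1/models reports; "default" is only illustrative):
#
#     import requests
#     res = requests.post(
#         "http://localhost:30000/v1/chat/completions",
#         json={
#             "model": "default",
#             "messages": [{"role": "user", "content": "What is the capital of France?"}],
#         },
#     )
#     print(res.json()["choices"][0]["message"]["content"])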

@app.post("/v1/embeddings", response_class=ORJSONResponse)
@time_func_latency
async def openai_v1_embeddings(raw_request: Request):
    response = await v1_embeddings(tokenizer_manager, raw_request)
    return response


@app.get("/v1/models", response_class=ORJSONResponse)
def available_models():
    """Show available models."""
    served_model_names = [tokenizer_manager.served_model_name]
    model_cards = []
    for served_model_name in served_model_names:
        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
    return ModelList(data=model_cards)


@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
    return await v1_files_create(
        file, purpose, tokenizer_manager.server_args.file_storage_pth
    )


@app.delete("/v1/files/{file_id}")
async def delete_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/delete
    return await v1_delete_file(file_id)


@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
    return await v1_batches(tokenizer_manager, raw_request)


@app.post("/v1/batches/{batch_id}/cancel")
async def cancel_batches(batch_id: str):
    # https://platform.openai.com/docs/api-reference/batch/cancel
    return await v1_cancel_batch(tokenizer_manager, batch_id)


@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
    return await v1_retrieve_batch(batch_id)


@app.get("/v1/files/{file_id}")
async def retrieve_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve
    return await v1_retrieve_file(file_id)


@app.get("/v1/files/{file_id}/content")
async def retrieve_file_content(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve-contents
    return await v1_retrieve_file_content(file_id)


def launch_engine(
    server_args: ServerArgs,
):
    """
    Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess.
    """

    global tokenizer_manager
    global _max_total_num_tokens

    # Configure global environment
    configure_logger(server_args)
    server_args.check_server_args()
    _set_envs_and_config(server_args)

    # Allocate ports for inter-process communications
    port_args = PortArgs.init_new(server_args)
    logger.info(f"{server_args=}")

    # If using a model from www.modelscope.cn, download the model first.
    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
        server_args.model_path, server_args.tokenizer_path
    )

    if server_args.dp_size == 1:
        # Launch tensor parallel scheduler processes
        scheduler_procs = []
        scheduler_pipe_readers = []
        tp_size_per_node = server_args.tp_size // server_args.nnodes
        tp_rank_range = range(
            tp_size_per_node * server_args.node_rank,
            tp_size_per_node * (server_args.node_rank + 1),
        )
        for tp_rank in tp_rank_range:
            reader, writer = mp.Pipe(duplex=False)
            gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
            proc = mp.Process(
                target=run_scheduler_process,
                args=(server_args, port_args, gpu_id, tp_rank, None, writer),
            )
            proc.start()
            scheduler_procs.append(proc)
            scheduler_pipe_readers.append(reader)

        if server_args.node_rank >= 1:
            # Nodes other than rank 0 do not need to run the tokenizer or detokenizer,
            # so they just wait here.
            while True:
                pass
    else:
        # Launch the data parallel controller
        reader, writer = mp.Pipe(duplex=False)
        scheduler_pipe_readers = [reader]
        proc = mp.Process(
            target=run_data_parallel_controller_process,
            args=(server_args, port_args, writer),
        )
        proc.start()

    # Launch the detokenizer process
    detoken_proc = mp.Process(
        target=run_detokenizer_process,
        args=(
            server_args,
            port_args,
        ),
    )
    detoken_proc.start()

    # Launch the tokenizer manager in the main process
    tokenizer_manager = TokenizerManager(server_args, port_args)
    if server_args.chat_template:
        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)

    # Wait for the model to finish loading and get the max token count
    scheduler_info = []
    for i in range(len(scheduler_pipe_readers)):
        data = scheduler_pipe_readers[i].recv()

        if data["status"] != "ready":
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )
        scheduler_info.append(data)

    # Assume all schedulers have the same max_total_num_tokens
    _max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]


def launch_server(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[mp.connection.Connection] = None,
):
    """
    Launch SRT (SGLang Runtime) Server

    The SRT server consists of an HTTP server and the SRT engine.

    1. HTTP server: A FastAPI server that routes requests to the engine.
    2. SRT engine:
        1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler.
        2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
        3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.

    Note:
    1. The HTTP server and Tokenizer Manager both run in the main process.
    2. Inter-process communication is done via the ZMQ library (each process uses a different port).
    """
    launch_engine(server_args=server_args)

    # Add API key authorization
    if server_args.api_key:
        add_api_key_middleware(app, server_args.api_key)

    # Add Prometheus middleware
    if server_args.enable_metrics:
        add_prometheus_middleware(app)
        enable_func_timer()

    # Send a warmup request
    t = threading.Thread(
        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
    )
    t.start()

    try:
        # Listen for HTTP requests
        LOGGING_CONFIG["formatters"]["default"][
            "fmt"
        ] = "[%(asctime)s] %(levelprefix)s %(message)s"
        LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
        LOGGING_CONFIG["formatters"]["access"][
            "fmt"
        ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
        LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level_http or server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )
    finally:
        t.join()
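# Programmatic launch sketch (field names follow ServerArgs in server_args.py; the
# model path is only illustrative):
#
#     if __name__ == "__main__":
#         launch_server(ServerArgs(model_path="meta-llama/Llama-3.1-8B-Instruct"))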

async def _get_server_info():
    return {
        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
        "memory_pool_size": await tokenizer_manager.get_memory_pool_size(),  # memory pool size
        "max_total_num_tokens": _max_total_num_tokens,  # max total num tokens
        "version": __version__,
    }


def _set_envs_and_config(server_args: ServerArgs):
    # Set global environments
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["NCCL_CUMEM_ENABLE"] = "0"
    os.environ["NCCL_NVLS_ENABLE"] = "0"
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"

    # Set Prometheus env vars
    if server_args.enable_metrics:
        set_prometheus_multiproc_dir()

    # Set ulimit
    set_ulimit()

    # Fix triton bugs
    if server_args.tp_size * server_args.dp_size > 1:
        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
        maybe_set_triton_cache_manager()

    # Check the flashinfer version
    if server_args.attention_backend == "flashinfer":
        assert_pkg_version(
            "flashinfer",
            "0.1.6",
            "Please uninstall the old version and "
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
        )

    # Register the signal handler.
    # The child processes will send SIGQUIT to this process when any error happens.
    # This process then cleans up the whole process tree.
    def sigquit_handler(signum, frame):
        kill_process_tree(os.getpid())

    signal.signal(signal.SIGQUIT, sigquit_handler)

    # Set the multiprocessing start method
    mp.set_start_method("spawn", force=True)

def _wait_and_warmup(server_args, pipe_finish_writer):
    headers = {}
    url = server_args.url()
    if server_args.api_key:
        headers["Authorization"] = f"Bearer {server_args.api_key}"

    # Wait until the server is launched
    success = False
    for _ in range(120):
        time.sleep(1)
        try:
            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
            assert res.status_code == 200, f"{res=}, {res.text=}"
            success = True
            break
        except (AssertionError, requests.exceptions.RequestException):
            last_traceback = get_exception_traceback()
            pass

    if not success:
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_process_tree(os.getpid())
        return

    model_info = res.json()

    # Send a warmup request
    request_name = "/generate" if model_info["is_generation"] else "/encode"
    max_new_tokens = 8 if model_info["is_generation"] else 1
    json_data = {
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": max_new_tokens,
        },
    }
    if server_args.skip_tokenizer_init:
        json_data["input_ids"] = [10, 11, 12]
    else:
        json_data["text"] = "The capital city of France is"

    try:
        for _ in range(server_args.dp_size):
            res = requests.post(
                url + request_name,
                json=json_data,
                headers=headers,
                timeout=600,
            )
            assert res.status_code == 200, f"{res}"
    except Exception:
        last_traceback = get_exception_traceback()
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_process_tree(os.getpid())
        return

    # logger.info(f"{res.json()=}")

    logger.info("The server is fired up and ready to roll!")
    if pipe_finish_writer is not None:
        pipe_finish_writer.send("ready")

    if server_args.delete_ckpt_after_loading:
        delete_directory(server_args.model_path)


class Runtime:
    """
    A wrapper for the server.
    This is used for launching the server in a Python program without
    using the command line interface.
    """

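    # Usage sketch (the model path below is only illustrative):
    #
    #     runtime = Runtime(model_path="meta-llama/Llama-3.1-8B-Instruct")
    #     print(runtime.generate("The capital city of France is"))
    #     runtime.shutdown()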
    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Before the Python program terminates, call shutdown implicitly so users
        # don't have to call .shutdown() explicitly.
        atexit.register(self.shutdown)

        # Pre-allocate ports
        for port in range(10000, 40000):
            if is_port_available(port):
                break
            port += 1
        self.server_args.port = port

        self.url = self.server_args.url()
        self.generate_url = self.url + "/generate"

        # NOTE: We store the pid instead of the proc to fix some issues during __del__
        self.pid = None
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)

        proc = mp.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "ready":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        if self.pid is not None:
            kill_process_tree(self.pid)
            self.pid = None

    def cache_prefix(self, prefix: str):
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        if self.server_args.skip_tokenizer_init:
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if "text" in data:
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    add_request = async_generate

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        json_data = {"text": prompt}
        response = requests.post(self.url + "/encode", json=json_data)
        return json.dumps(response.json())

    async def get_server_info(self):
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.url}/get_server_info") as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_data = await response.json()
                    raise RuntimeError(
                        f"Failed to get server info. {error_data['error']['message']}"
                    )

    def __del__(self):
        self.shutdown()


STREAM_END_SYMBOL = b"data: [DONE]"
STREAM_CHUNK_START_SYMBOL = b"data:"


class Engine:
    """
    SRT Engine without an HTTP server layer.

    This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
    launching the HTTP server adds unnecessary complexity or overhead.
    """

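    # Usage sketch (the model path below is only illustrative):
    #
    #     engine = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")
    #     out = engine.generate(
    #         "The capital city of France is",
    #         sampling_params={"temperature": 0, "max_new_tokens": 8},
    #     )
    #     print(out["text"])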
    def __init__(self, *args, **kwargs):
        # Before the Python program terminates, call shutdown implicitly so users
        # don't have to call .shutdown() explicitly.
        atexit.register(self.shutdown)

        # The runtime server uses its own default log level; the offline engine runs
        # in scripts, so default to "error" unless the caller overrides it.
        if "log_level" not in kwargs:
            kwargs["log_level"] = "error"

        server_args = ServerArgs(*args, **kwargs)
        launch_engine(server_args=server_args)

    def generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Union[List[Dict], Dict]] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        stream: bool = False,
    ):
        obj = GenerateReqInput(
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            lora_path=lora_path,
            stream=stream,
        )

        # get the current event loop
        loop = asyncio.get_event_loop()
        ret = loop.run_until_complete(generate_request(obj, None))

        if stream is True:

            def generator_wrapper():
                offset = 0
                loop = asyncio.get_event_loop()
                generator = ret.body_iterator
                while True:
                    chunk = loop.run_until_complete(generator.__anext__())

                    if chunk.startswith(STREAM_END_SYMBOL):
                        break
                    else:
                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
                        data["text"] = data["text"][offset:]
                        offset += len(data["text"])
                        yield data

            # We cannot yield inside generate() itself: a function body containing
            # `yield` is always a generator, but the non-streaming path must return a
            # plain result. Wrapping the generator in an inner function avoids this.
            return generator_wrapper()
        else:
            return ret
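    # Streaming usage sketch (assumes an Engine instance created as in the example above):
    #
    #     for chunk in engine.generate("Hello", {"max_new_tokens": 16}, stream=True):
    #         print(chunk["text"], end="", flush=True)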

    async def async_generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Dict] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        stream: bool = False,
    ):
        obj = GenerateReqInput(
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            lora_path=lora_path,
            stream=stream,
        )

        ret = await generate_request(obj, None)

        if stream is True:
            generator = ret.body_iterator

            async def generator_wrapper():

                offset = 0

                while True:
                    chunk = await generator.__anext__()

                    if chunk.startswith(STREAM_END_SYMBOL):
                        break
                    else:
                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
                        data["text"] = data["text"][offset:]
                        offset += len(data["text"])
                        yield data

            return generator_wrapper()
        else:
            return ret

    def shutdown(self):
        kill_process_tree(os.getpid(), include_parent=False)

    def get_tokenizer(self):
        global tokenizer_manager

        if tokenizer_manager is None:
            raise ReferenceError("Tokenizer Manager is not initialized.")
        else:
            return tokenizer_manager.tokenizer

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        obj = EmbeddingReqInput(text=prompt)

        # get the current event loop
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(encode_request(obj, None))

    async def get_server_info(self):
        return await _get_server_info()