server.py
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The entry point of the inference server.
SRT = SGLang Runtime.
"""
import asyncio
import atexit
import dataclasses
import json
import logging
import multiprocessing as mp
import os
import threading
import time
from http import HTTPStatus
from typing import AsyncIterator, Dict, List, Optional, Union

# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

import aiohttp
import orjson
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from uvicorn.config import LOGGING_CONFIG

from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.data_parallel_controller import (
    run_data_parallel_controller_process,
)
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
from sglang.srt.managers.io_struct import (
    CloseSessionReqInput,
    EmbeddingReqInput,
    GenerateReqInput,
    OpenSessionReqInput,
    UpdateWeightReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer, time_func_latency
from sglang.srt.openai_api.adapter import (
    load_chat_template_for_openai_api,
    v1_batches,
    v1_cancel_batch,
    v1_chat_completions,
    v1_completions,
    v1_delete_file,
    v1_embeddings,
    v1_files_create,
    v1_retrieve_batch,
    v1_retrieve_file,
    v1_retrieve_file_content,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    add_api_key_middleware,
    add_prometheus_middleware,
    assert_pkg_version,
    configure_logger,
    delete_directory,
    is_port_available,
    kill_child_process,
    maybe_set_triton_cache_manager,
    prepare_model_and_tokenizer,
    set_prometheus_multiproc_dir,
    set_ulimit,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

tokenizer_manager: TokenizerManager = None
_max_total_num_tokens = None


##### Native API endpoints #####

@app.get("/health")
async def health() -> Response:
    """Check the health of the http server."""
    return Response(status_code=200)


@app.get("/health_generate")
async def health_generate(request: Request) -> Response:
    """Check the health of the inference server by generating one token."""

    if tokenizer_manager.is_generation:
        gri = GenerateReqInput(
            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
        )
    else:
        gri = EmbeddingReqInput(
            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
        )

    try:
        async for _ in tokenizer_manager.generate_request(gri, request):
            break
        return Response(status_code=200)
    except Exception as e:
        logger.exception(e)
        return Response(status_code=503)


@app.get("/get_model_info")
async def get_model_info():
    """Get the model information."""
    result = {
        "model_path": tokenizer_manager.model_path,
        "tokenizer_path": tokenizer_manager.server_args.tokenizer_path,
        "is_generation": tokenizer_manager.is_generation,
    }
    return result

@app.get("/get_server_info")
async def get_server_info():
    try:
        return await _get_server_info()

    except Exception as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


@app.post("/flush_cache")
async def flush_cache():
    """Flush the radix cache."""
    tokenizer_manager.flush_cache()
    return Response(
        content="Cache flushed.\nPlease check backend logs for more details. "
        "(When there are running or waiting requests, the operation will not be performed.)\n",
        status_code=200,
    )


@app.get("/start_profile")
@app.post("/start_profile")
async def start_profile():
    """Start profiling."""
    tokenizer_manager.start_profile()
    return Response(
        content="Start profiling.\n",
        status_code=200,
    )


@app.get("/stop_profile")
@app.post("/stop_profile")
async def stop_profile():
    """Stop profiling."""
    tokenizer_manager.stop_profile()
    return Response(
        content="Stop profiling. This will take some time.\n",
        status_code=200,
    )


@app.post("/update_weights")
@time_func_latency
async def update_weights(obj: UpdateWeightReqInput, request: Request):
    """Update the weights in place without re-launching the server."""
    success, message = await tokenizer_manager.update_weights(obj, request)
    content = {"success": success, "message": message}
    if success:
        return ORJSONResponse(
            content,
            status_code=HTTPStatus.OK,
        )
    else:
        return ORJSONResponse(
            content,
            status_code=HTTPStatus.BAD_REQUEST,
        )


@app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request):
    """Open a session, and return its unique session id."""
    try:
        session_id = await tokenizer_manager.open_session(obj, request)
        return session_id
    except Exception as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


@app.api_route("/close_session", methods=["GET", "POST"])
async def close_session(obj: CloseSessionReqInput, request: Request):
    """Close the session"""
    try:
        await tokenizer_manager.close_session(obj, request)
        return Response(status_code=200)
    except Exception as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


@time_func_latency
async def generate_request(obj: GenerateReqInput, request: Request):
    """Handle a generate request."""
    if obj.stream:

        async def stream_results() -> AsyncIterator[bytes]:
            try:
                async for out in tokenizer_manager.generate_request(obj, request):
                    yield b"data: " + orjson.dumps(
                        out, option=orjson.OPT_NON_STR_KEYS
                    ) + b"\n\n"
            except ValueError as e:
                out = {"error": {"message": str(e)}}
                yield b"data: " + orjson.dumps(
                    out, option=orjson.OPT_NON_STR_KEYS
                ) + b"\n\n"
            yield b"data: [DONE]\n\n"

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
            background=tokenizer_manager.create_abort_task(obj),
        )
    else:
        try:
            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
            return ret
        except ValueError as e:
            return ORJSONResponse(
                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
            )


# fastapi implicitly converts json in the request to obj (dataclass)
app.post("/generate")(generate_request)
app.put("/generate")(generate_request)

@time_func_latency
async def encode_request(obj: EmbeddingReqInput, request: Request):
    """Handle an embedding request."""
    try:
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


app.post("/encode")(encode_request)
app.put("/encode")(encode_request)


@time_func_latency
async def classify_request(obj: EmbeddingReqInput, request: Request):
    """Handle a reward model request. Currently, the arguments and return values are the same as for embedding models."""
    try:
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


app.post("/classify")(classify_request)
app.put("/classify")(classify_request)


##### OpenAI-compatible API endpoints #####


@app.post("/v1/completions")
@time_func_latency
async def openai_v1_completions(raw_request: Request):
    return await v1_completions(tokenizer_manager, raw_request)


@app.post("/v1/chat/completions")
@time_func_latency
async def openai_v1_chat_completions(raw_request: Request):
    return await v1_chat_completions(tokenizer_manager, raw_request)


@app.post("/v1/embeddings", response_class=ORJSONResponse)
@time_func_latency
async def openai_v1_embeddings(raw_request: Request):
    response = await v1_embeddings(tokenizer_manager, raw_request)
    return response


@app.get("/v1/models", response_class=ORJSONResponse)
def available_models():
    """Show available models."""
    served_model_names = [tokenizer_manager.served_model_name]
    model_cards = []
    for served_model_name in served_model_names:
        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
    return ModelList(data=model_cards)


@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
    return await v1_files_create(
        file, purpose, tokenizer_manager.server_args.file_storage_pth
    )


@app.delete("/v1/files/{file_id}")
async def delete_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/delete
    return await v1_delete_file(file_id)


@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
    return await v1_batches(tokenizer_manager, raw_request)


@app.post("/v1/batches/{batch_id}/cancel")
async def cancel_batches(batch_id: str):
    # https://platform.openai.com/docs/api-reference/batch/cancel
    return await v1_cancel_batch(tokenizer_manager, batch_id)


@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
    return await v1_retrieve_batch(batch_id)


@app.get("/v1/files/{file_id}")
async def retrieve_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve
    return await v1_retrieve_file(file_id)


@app.get("/v1/files/{file_id}/content")
async def retrieve_file_content(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve-contents
    return await v1_retrieve_file_content(file_id)


def launch_engine(
    server_args: ServerArgs,
):
    """
    Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess.
    """

    global tokenizer_manager
    global _max_total_num_tokens

    # Configure global environment
    configure_logger(server_args)
    server_args.check_server_args()
    _set_envs_and_config(server_args)

    # Allocate ports for inter-process communications
    port_args = PortArgs.init_new(server_args)
    logger.info(f"{server_args=}")

    # If using model from www.modelscope.cn, first download the model.
    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
        server_args.model_path, server_args.tokenizer_path
    )

    if server_args.dp_size == 1:
        # Launch tensor parallel scheduler processes
        scheduler_procs = []
        scheduler_pipe_readers = []
        tp_size_per_node = server_args.tp_size // server_args.nnodes
        tp_rank_range = range(
            tp_size_per_node * server_args.node_rank,
            tp_size_per_node * (server_args.node_rank + 1),
        )
        for tp_rank in tp_rank_range:
            reader, writer = mp.Pipe(duplex=False)
            gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
            proc = mp.Process(
                target=run_scheduler_process,
                args=(server_args, port_args, gpu_id, tp_rank, None, writer),
            )
            proc.start()
            scheduler_procs.append(proc)
            scheduler_pipe_readers.append(reader)

        if server_args.node_rank >= 1:
            # For other nodes, they do not need to run tokenizer or detokenizer,
            # so they can just wait here.
            while True:
                pass
    else:
        # Launch the data parallel controller
        reader, writer = mp.Pipe(duplex=False)
        scheduler_pipe_readers = [reader]
        proc = mp.Process(
            target=run_data_parallel_controller_process,
            args=(server_args, port_args, writer),
        )
        proc.start()

    # Launch detokenizer process
    detoken_proc = mp.Process(
        target=run_detokenizer_process,
        args=(
            server_args,
            port_args,
        ),
    )
    detoken_proc.start()

    # Launch tokenizer process
    tokenizer_manager = TokenizerManager(server_args, port_args)
    if server_args.chat_template:
        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)

    # Wait for model to finish loading & get max token nums
    scheduler_info = []
    for i in range(len(scheduler_pipe_readers)):
        data = scheduler_pipe_readers[i].recv()

        if data["status"] != "ready":
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )
        scheduler_info.append(data)

    # Assume all schedulers have same max_total_num_tokens
    _max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]


def launch_server(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[mp.connection.Connection] = None,
):
    """
    Launch SRT (SGLang Runtime) Server

    The SRT server consists of an HTTP server and the SRT engine.

    1. HTTP server: A FastAPI server that routes requests to the engine.
    2. SRT engine:
        1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler.
        2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
        3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.

    Note:
    1. The HTTP server and Tokenizer Manager both run in the main process.
    2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
    """
    launch_engine(server_args=server_args)

    # Add api key authorization
    if server_args.api_key:
        add_api_key_middleware(app, server_args.api_key)

    # add prometheus middleware
    if server_args.enable_metrics:
        add_prometheus_middleware(app)
        enable_func_timer()

    # Send a warmup request
    t = threading.Thread(
        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
    )
    t.start()

    try:
        # Listen for HTTP requests
        LOGGING_CONFIG["formatters"]["default"][
            "fmt"
        ] = "[%(asctime)s] %(levelprefix)s %(message)s"
        LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
        LOGGING_CONFIG["formatters"]["access"][
            "fmt"
        ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
        LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level_http or server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )
    finally:
        t.join()


async def _get_server_info():
    return {
        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
        "memory_pool_size": await tokenizer_manager.get_memory_pool_size(),  # memory pool size
        "max_total_num_tokens": _max_total_num_tokens,  # max total num tokens
    }


def _set_envs_and_config(server_args: ServerArgs):
    # Set global environments
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["NCCL_CUMEM_ENABLE"] = "0"
    os.environ["NCCL_NVLS_ENABLE"] = "0"
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"

    # Set prometheus env vars
    if server_args.enable_metrics:
        set_prometheus_multiproc_dir()

    # Set ulimit
    set_ulimit()

    # Fix triton bugs
    if server_args.tp_size * server_args.dp_size > 1:
        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
        maybe_set_triton_cache_manager()

    # Check flashinfer version
    if server_args.attention_backend == "flashinfer":
        assert_pkg_version(
            "flashinfer",
            "0.1.6",
            "Please uninstall the old version and "
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
        )

    mp.set_start_method("spawn", force=True)


def _wait_and_warmup(server_args, pipe_finish_writer):
    headers = {}
    url = server_args.url()
    if server_args.api_key:
        headers["Authorization"] = f"Bearer {server_args.api_key}"

    # Wait until the server is launched
    success = False
    for _ in range(120):
        time.sleep(1)
        try:
            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
            assert res.status_code == 200, f"{res=}, {res.text=}"
            success = True
            break
        except (AssertionError, requests.exceptions.RequestException):
            last_traceback = get_exception_traceback()
            pass

    if not success:
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_child_process(include_self=True)
        return

    model_info = res.json()

    # Send a warmup request
    request_name = "/generate" if model_info["is_generation"] else "/encode"
    max_new_tokens = 8 if model_info["is_generation"] else 1
    json_data = {
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": max_new_tokens,
        },
    }
    if server_args.skip_tokenizer_init:
        json_data["input_ids"] = [10, 11, 12]
    else:
        json_data["text"] = "The capital city of France is"

    try:
        for _ in range(server_args.dp_size):
            res = requests.post(
                url + request_name,
                json=json_data,
                headers=headers,
                timeout=600,
            )
            assert res.status_code == 200, f"{res}"
    except Exception:
        last_traceback = get_exception_traceback()
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_child_process(include_self=True)
        return

    # logger.info(f"{res.json()=}")

    logger.info("The server is fired up and ready to roll!")
    if pipe_finish_writer is not None:
        pipe_finish_writer.send("ready")

    if server_args.delete_ckpt_after_loading:
        delete_directory(server_args.model_path)


class Runtime:
    """
    A wrapper for the server.
    This is used for launching the server in a Python program without
    using the command line interface.
    """

    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Register shutdown() to run before the Python program terminates, so users don't have to call .shutdown() explicitly.
        atexit.register(self.shutdown)

        # Pre-allocate ports
        for port in range(10000, 40000):
            if is_port_available(port):
                break
        self.server_args.port = port

        self.url = self.server_args.url()
        self.generate_url = self.url + "/generate"

        # NOTE: We store pid instead of proc to fix some issues during __del__
        self.pid = None
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)

        proc = mp.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "ready":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        if self.pid is not None:
            kill_child_process(self.pid, include_self=True)
            self.pid = None

    def cache_prefix(self, prefix: str):
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        if self.server_args.skip_tokenizer_init:
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if "text" in data:
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    add_request = async_generate

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        json_data = {"text": prompt}
        response = requests.post(self.url + "/encode", json=json_data)
        return json.dumps(response.json())

    async def get_server_info(self):
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.url}/get_server_info") as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_data = await response.json()
                    raise RuntimeError(
                        f"Failed to get server info. {error_data['error']['message']}"
                    )

    def __del__(self):
        self.shutdown()


STREAM_END_SYMBOL = b"data: [DONE]"
STREAM_CHUNK_START_SYMBOL = b"data:"


class Engine:
    """
    SRT Engine without an HTTP server layer.

    This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
    launching the HTTP server adds unnecessary complexity or overhead.
    """

    def __init__(self, *args, **kwargs):

        # Register shutdown() to run before the Python program terminates, so users don't have to call .shutdown() explicitly.
        atexit.register(self.shutdown)

        # The runtime server uses a more verbose default log level; the offline
        # engine runs in scripts, so we set it to "error" instead.

        if "log_level" not in kwargs:
            kwargs["log_level"] = "error"

        server_args = ServerArgs(*args, **kwargs)
        launch_engine(server_args=server_args)

    def generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Union[List[Dict], Dict]] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        stream: bool = False,
    ):
        obj = GenerateReqInput(
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            lora_path=lora_path,
            stream=stream,
        )

        # get the current event loop
        loop = asyncio.get_event_loop()
        ret = loop.run_until_complete(generate_request(obj, None))

        if stream is True:

            def generator_wrapper():
                offset = 0
                loop = asyncio.get_event_loop()
                generator = ret.body_iterator
                while True:
                    chunk = loop.run_until_complete(generator.__anext__())

                    if chunk.startswith(STREAM_END_SYMBOL):
                        break
                    else:
                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
                        data["text"] = data["text"][offset:]
                        offset += len(data["text"])
                        yield data

            # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
            # however, it allows to wrap the generator as a subfunction and return
            return generator_wrapper()
        else:
            return ret

    async def async_generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Dict] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        stream: bool = False,
    ):
        obj = GenerateReqInput(
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            lora_path=lora_path,
            stream=stream,
        )

        ret = await generate_request(obj, None)

        if stream is True:
            generator = ret.body_iterator

            async def generator_wrapper():

                offset = 0

                while True:
                    chunk = await generator.__anext__()

                    if chunk.startswith(STREAM_END_SYMBOL):
                        break
                    else:
                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
                        data["text"] = data["text"][offset:]
                        offset += len(data["text"])
                        yield data

            return generator_wrapper()
        else:
            return ret

    def shutdown(self):
        kill_child_process()

    def get_tokenizer(self):
        global tokenizer_manager

        if tokenizer_manager is None:
            raise ReferenceError("Tokenizer Manager is not initialized.")
        else:
            return tokenizer_manager.tokenizer

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        obj = EmbeddingReqInput(text=prompt)

        # get the current event loop
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(encode_request(obj, None))

    async def get_server_info(self):
        return await _get_server_info()